Coverage Report

Created: 2024-01-17 10:31

/src/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
Line
Count
Source (jump to first uncovered line)
1
//===- InstCombineShifts.cpp ----------------------------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file implements the visitShl, visitLShr, and visitAShr functions.
10
//
11
//===----------------------------------------------------------------------===//
12
13
#include "InstCombineInternal.h"
14
#include "llvm/Analysis/InstructionSimplify.h"
15
#include "llvm/IR/IntrinsicInst.h"
16
#include "llvm/IR/PatternMatch.h"
17
#include "llvm/Transforms/InstCombine/InstCombiner.h"
18
using namespace llvm;
19
using namespace PatternMatch;
20
21
#define DEBUG_TYPE "instcombine"
22
23
bool canTryToConstantAddTwoShiftAmounts(Value *Sh0, Value *ShAmt0, Value *Sh1,
24
2.92k
                                        Value *ShAmt1) {
25
  // We have two shift amounts from two different shifts. The types of those
26
  // shift amounts may not match. If that's the case let's bailout now..
27
2.92k
  if (ShAmt0->getType() != ShAmt1->getType())
28
157
    return false;
29
30
  // As input, we have the following pattern:
31
  //   Sh0 (Sh1 X, Q), K
32
  // We want to rewrite that as:
33
  //   Sh x, (Q+K)  iff (Q+K) u< bitwidth(x)
34
  // While we know that originally (Q+K) would not overflow
35
  // (because  2 * (N-1) u<= iN -1), we have looked past extensions of
36
  // shift amounts. so it may now overflow in smaller bitwidth.
37
  // To ensure that does not happen, we need to ensure that the total maximal
38
  // shift amount is still representable in that smaller bit width.
39
2.76k
  unsigned MaximalPossibleTotalShiftAmount =
40
2.76k
      (Sh0->getType()->getScalarSizeInBits() - 1) +
41
2.76k
      (Sh1->getType()->getScalarSizeInBits() - 1);
42
2.76k
  APInt MaximalRepresentableShiftAmount =
43
2.76k
      APInt::getAllOnes(ShAmt0->getType()->getScalarSizeInBits());
44
2.76k
  return MaximalRepresentableShiftAmount.uge(MaximalPossibleTotalShiftAmount);
45
2.92k
}
46
47
// Given pattern:
48
//   (x shiftopcode Q) shiftopcode K
49
// we should rewrite it as
50
//   x shiftopcode (Q+K)  iff (Q+K) u< bitwidth(x) and
51
//
52
// This is valid for any shift, but they must be identical, and we must be
53
// careful in case we have (zext(Q)+zext(K)) and look past extensions,
54
// (Q+K) must not overflow or else (Q+K) u< bitwidth(x) is bogus.
55
//
56
// AnalyzeForSignBitExtraction indicates that we will only analyze whether this
57
// pattern has any 2 right-shifts that sum to 1 less than original bit width.
58
Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts(
59
    BinaryOperator *Sh0, const SimplifyQuery &SQ,
60
52.7k
    bool AnalyzeForSignBitExtraction) {
61
  // Look for a shift of some instruction, ignore zext of shift amount if any.
62
52.7k
  Instruction *Sh0Op0;
63
52.7k
  Value *ShAmt0;
64
52.7k
  if (!match(Sh0,
65
52.7k
             m_Shift(m_Instruction(Sh0Op0), m_ZExtOrSelf(m_Value(ShAmt0)))))
66
34.7k
    return nullptr;
67
68
  // If there is a truncation between the two shifts, we must make note of it
69
  // and look through it. The truncation imposes additional constraints on the
70
  // transform.
71
17.9k
  Instruction *Sh1;
72
17.9k
  Value *Trunc = nullptr;
73
17.9k
  match(Sh0Op0,
74
17.9k
        m_CombineOr(m_CombineAnd(m_Trunc(m_Instruction(Sh1)), m_Value(Trunc)),
75
17.9k
                    m_Instruction(Sh1)));
76
77
  // Inner shift: (x shiftopcode ShAmt1)
78
  // Like with other shift, ignore zext of shift amount if any.
79
17.9k
  Value *X, *ShAmt1;
80
17.9k
  if (!match(Sh1, m_Shift(m_Value(X), m_ZExtOrSelf(m_Value(ShAmt1)))))
81
15.1k
    return nullptr;
82
83
  // Verify that it would be safe to try to add those two shift amounts.
84
2.81k
  if (!canTryToConstantAddTwoShiftAmounts(Sh0, ShAmt0, Sh1, ShAmt1))
85
157
    return nullptr;
86
87
  // We are only looking for signbit extraction if we have two right shifts.
88
2.65k
  bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) &&
89
2.65k
                           match(Sh1, m_Shr(m_Value(), m_Value()));
90
  // ... and if it's not two right-shifts, we know the answer already.
91
2.65k
  if (AnalyzeForSignBitExtraction && !HadTwoRightShifts)
92
40
    return nullptr;
93
94
  // The shift opcodes must be identical, unless we are just checking whether
95
  // this pattern can be interpreted as a sign-bit-extraction.
96
2.61k
  Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode();
97
2.61k
  bool IdenticalShOpcodes = Sh0->getOpcode() == Sh1->getOpcode();
98
2.61k
  if (!IdenticalShOpcodes && !AnalyzeForSignBitExtraction)
99
2.12k
    return nullptr;
100
101
  // If we saw truncation, we'll need to produce extra instruction,
102
  // and for that one of the operands of the shift must be one-use,
103
  // unless of course we don't actually plan to produce any instructions here.
104
485
  if (Trunc && !AnalyzeForSignBitExtraction &&
105
485
      !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
106
17
    return nullptr;
107
108
  // Can we fold (ShAmt0+ShAmt1) ?
109
468
  auto *NewShAmt = dyn_cast_or_null<Constant>(
110
468
      simplifyAddInst(ShAmt0, ShAmt1, /*isNSW=*/false, /*isNUW=*/false,
111
468
                      SQ.getWithInstruction(Sh0)));
112
468
  if (!NewShAmt)
113
318
    return nullptr; // Did not simplify.
114
150
  unsigned NewShAmtBitWidth = NewShAmt->getType()->getScalarSizeInBits();
115
150
  unsigned XBitWidth = X->getType()->getScalarSizeInBits();
116
  // Is the new shift amount smaller than the bit width of inner/new shift?
117
150
  if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
118
150
                                          APInt(NewShAmtBitWidth, XBitWidth))))
119
26
    return nullptr; // FIXME: could perform constant-folding.
120
121
  // If there was a truncation, and we have a right-shift, we can only fold if
122
  // we are left with the original sign bit. Likewise, if we were just checking
123
  // that this is a sighbit extraction, this is the place to check it.
124
  // FIXME: zero shift amount is also legal here, but we can't *easily* check
125
  // more than one predicate so it's not really worth it.
126
124
  if (HadTwoRightShifts && (Trunc || AnalyzeForSignBitExtraction)) {
127
    // If it's not a sign bit extraction, then we're done.
128
5
    if (!match(NewShAmt,
129
5
               m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
130
5
                                  APInt(NewShAmtBitWidth, XBitWidth - 1))))
131
2
      return nullptr;
132
    // If it is, and that was the question, return the base value.
133
3
    if (AnalyzeForSignBitExtraction)
134
0
      return X;
135
3
  }
136
137
122
  assert(IdenticalShOpcodes && "Should not get here with different shifts.");
138
139
122
  if (NewShAmt->getType() != X->getType()) {
140
34
    NewShAmt = ConstantFoldCastOperand(Instruction::ZExt, NewShAmt,
141
34
                                       X->getType(), SQ.DL);
142
34
    if (!NewShAmt)
143
0
      return nullptr;
144
34
  }
145
146
  // All good, we can do this fold.
147
122
  BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt);
148
149
  // The flags can only be propagated if there wasn't a trunc.
150
122
  if (!Trunc) {
151
    // If the pattern did not involve trunc, and both of the original shifts
152
    // had the same flag set, preserve the flag.
153
88
    if (ShiftOpcode == Instruction::BinaryOps::Shl) {
154
15
      NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() &&
155
15
                                     Sh1->hasNoUnsignedWrap());
156
15
      NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() &&
157
15
                                   Sh1->hasNoSignedWrap());
158
73
    } else {
159
73
      NewShift->setIsExact(Sh0->isExact() && Sh1->isExact());
160
73
    }
161
88
  }
162
163
122
  Instruction *Ret = NewShift;
164
122
  if (Trunc) {
165
34
    Builder.Insert(NewShift);
166
34
    Ret = CastInst::Create(Instruction::Trunc, NewShift, Sh0->getType());
167
34
  }
168
169
122
  return Ret;
170
122
}
171
172
// If we have some pattern that leaves only some low bits set, and then performs
173
// left-shift of those bits, if none of the bits that are left after the final
174
// shift are modified by the mask, we can omit the mask.
175
//
176
// There are many variants to this pattern:
177
//   a)  (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt
178
//   b)  (x & (~(-1 << MaskShAmt))) << ShiftShAmt
179
//   c)  (x & (-1 l>> MaskShAmt)) << ShiftShAmt
180
//   d)  (x & ((-1 << MaskShAmt) l>> MaskShAmt)) << ShiftShAmt
181
//   e)  ((x << MaskShAmt) l>> MaskShAmt) << ShiftShAmt
182
//   f)  ((x << MaskShAmt) a>> MaskShAmt) << ShiftShAmt
183
// All these patterns can be simplified to just:
184
//   x << ShiftShAmt
185
// iff:
186
//   a,b)     (MaskShAmt+ShiftShAmt) u>= bitwidth(x)
187
//   c,d,e,f) (ShiftShAmt-MaskShAmt) s>= 0 (i.e. ShiftShAmt u>= MaskShAmt)
188
static Instruction *
189
dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
190
                                     const SimplifyQuery &Q,
191
19.3k
                                     InstCombiner::BuilderTy &Builder) {
192
19.3k
  assert(OuterShift->getOpcode() == Instruction::BinaryOps::Shl &&
193
19.3k
         "The input must be 'shl'!");
194
195
0
  Value *Masked, *ShiftShAmt;
196
19.3k
  match(OuterShift,
197
19.3k
        m_Shift(m_Value(Masked), m_ZExtOrSelf(m_Value(ShiftShAmt))));
198
199
  // *If* there is a truncation between an outer shift and a possibly-mask,
200
  // then said truncation *must* be one-use, else we can't perform the fold.
201
19.3k
  Value *Trunc;
202
19.3k
  if (match(Masked, m_CombineAnd(m_Trunc(m_Value(Masked)), m_Value(Trunc))) &&
203
19.3k
      !Trunc->hasOneUse())
204
99
    return nullptr;
205
206
19.2k
  Type *NarrowestTy = OuterShift->getType();
207
19.2k
  Type *WidestTy = Masked->getType();
208
19.2k
  bool HadTrunc = WidestTy != NarrowestTy;
209
210
  // The mask must be computed in a type twice as wide to ensure
211
  // that no bits are lost if the sum-of-shifts is wider than the base type.
212
19.2k
  Type *ExtendedTy = WidestTy->getExtendedType();
213
214
19.2k
  Value *MaskShAmt;
215
216
  // ((1 << MaskShAmt) - 1)
217
19.2k
  auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes());
218
  // (~(-1 << maskNbits))
219
19.2k
  auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes());
220
  // (-1 l>> MaskShAmt)
221
19.2k
  auto MaskC = m_LShr(m_AllOnes(), m_Value(MaskShAmt));
222
  // ((-1 << MaskShAmt) l>> MaskShAmt)
223
19.2k
  auto MaskD =
224
19.2k
      m_LShr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt));
225
226
19.2k
  Value *X;
227
19.2k
  Constant *NewMask;
228
229
19.2k
  if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) {
230
    // Peek through an optional zext of the shift amount.
231
33
    match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt)));
232
233
    // Verify that it would be safe to try to add those two shift amounts.
234
33
    if (!canTryToConstantAddTwoShiftAmounts(OuterShift, ShiftShAmt, Masked,
235
33
                                            MaskShAmt))
236
5
      return nullptr;
237
238
    // Can we simplify (MaskShAmt+ShiftShAmt) ?
239
28
    auto *SumOfShAmts = dyn_cast_or_null<Constant>(simplifyAddInst(
240
28
        MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
241
28
    if (!SumOfShAmts)
242
5
      return nullptr; // Did not simplify.
243
    // In this pattern SumOfShAmts correlates with the number of low bits
244
    // that shall remain in the root value (OuterShift).
245
246
    // An extend of an undef value becomes zero because the high bits are never
247
    // completely unknown. Replace the `undef` shift amounts with final
248
    // shift bitwidth to ensure that the value remains undef when creating the
249
    // subsequent shift op.
250
23
    SumOfShAmts = Constant::replaceUndefsWith(
251
23
        SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
252
23
                                      ExtendedTy->getScalarSizeInBits()));
253
23
    auto *ExtendedSumOfShAmts = ConstantFoldCastOperand(
254
23
        Instruction::ZExt, SumOfShAmts, ExtendedTy, Q.DL);
255
23
    if (!ExtendedSumOfShAmts)
256
0
      return nullptr;
257
258
    // And compute the mask as usual: ~(-1 << (SumOfShAmts))
259
23
    auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
260
23
    auto *ExtendedInvertedMask =
261
23
        ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts);
262
23
    NewMask = ConstantExpr::getNot(ExtendedInvertedMask);
263
19.2k
  } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) ||
264
19.2k
             match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)),
265
19.2k
                                 m_Deferred(MaskShAmt)))) {
266
    // Peek through an optional zext of the shift amount.
267
76
    match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt)));
268
269
    // Verify that it would be safe to try to add those two shift amounts.
270
76
    if (!canTryToConstantAddTwoShiftAmounts(OuterShift, ShiftShAmt, Masked,
271
76
                                            MaskShAmt))
272
5
      return nullptr;
273
274
    // Can we simplify (ShiftShAmt-MaskShAmt) ?
275
71
    auto *ShAmtsDiff = dyn_cast_or_null<Constant>(simplifySubInst(
276
71
        ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
277
71
    if (!ShAmtsDiff)
278
25
      return nullptr; // Did not simplify.
279
    // In this pattern ShAmtsDiff correlates with the number of high bits that
280
    // shall be unset in the root value (OuterShift).
281
282
    // An extend of an undef value becomes zero because the high bits are never
283
    // completely unknown. Replace the `undef` shift amounts with negated
284
    // bitwidth of innermost shift to ensure that the value remains undef when
285
    // creating the subsequent shift op.
286
46
    unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits();
287
46
    ShAmtsDiff = Constant::replaceUndefsWith(
288
46
        ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(),
289
46
                                     -WidestTyBitWidth));
290
46
    auto *ExtendedNumHighBitsToClear = ConstantFoldCastOperand(
291
46
        Instruction::ZExt,
292
46
        ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(),
293
46
                                              WidestTyBitWidth,
294
46
                                              /*isSigned=*/false),
295
46
                             ShAmtsDiff),
296
46
        ExtendedTy, Q.DL);
297
46
    if (!ExtendedNumHighBitsToClear)
298
0
      return nullptr;
299
300
    // And compute the mask as usual: (-1 l>> (NumHighBitsToClear))
301
46
    auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
302
46
    NewMask = ConstantFoldBinaryOpOperands(Instruction::LShr, ExtendedAllOnes,
303
46
                                           ExtendedNumHighBitsToClear, Q.DL);
304
46
    if (!NewMask)
305
0
      return nullptr;
306
46
  } else
307
19.1k
    return nullptr; // Don't know anything about this pattern.
308
309
69
  NewMask = ConstantExpr::getTrunc(NewMask, NarrowestTy);
310
311
  // Does this mask has any unset bits? If not then we can just not apply it.
312
69
  bool NeedMask = !match(NewMask, m_AllOnes());
313
314
  // If we need to apply a mask, there are several more restrictions we have.
315
69
  if (NeedMask) {
316
    // The old masking instruction must go away.
317
58
    if (!Masked->hasOneUse())
318
30
      return nullptr;
319
    // The original "masking" instruction must not have been`ashr`.
320
28
    if (match(Masked, m_AShr(m_Value(), m_Value())))
321
5
      return nullptr;
322
28
  }
323
324
  // If we need to apply truncation, let's do it first, since we can.
325
  // We have already ensured that the old truncation will go away.
326
34
  if (HadTrunc)
327
11
    X = Builder.CreateTrunc(X, NarrowestTy);
328
329
  // No 'NUW'/'NSW'! We no longer know that we won't shift-out non-0 bits.
330
  // We didn't change the Type of this outermost shift, so we can just do it.
331
34
  auto *NewShift = BinaryOperator::Create(OuterShift->getOpcode(), X,
332
34
                                          OuterShift->getOperand(1));
333
34
  if (!NeedMask)
334
11
    return NewShift;
335
336
23
  Builder.Insert(NewShift);
337
23
  return BinaryOperator::Create(Instruction::And, NewShift, NewMask);
338
34
}
339
340
/// If we have a shift-by-constant of a bin op (bitwise logic op or add/sub w/
341
/// shl) that itself has a shift-by-constant operand with identical opcode, we
342
/// may be able to convert that into 2 independent shifts followed by the logic
343
/// op. This eliminates a use of an intermediate value (reduces dependency
344
/// chain).
345
static Instruction *foldShiftOfShiftedBinOp(BinaryOperator &I,
346
45.8k
                                            InstCombiner::BuilderTy &Builder) {
347
45.8k
  assert(I.isShift() && "Expected a shift as input");
348
0
  auto *BinInst = dyn_cast<BinaryOperator>(I.getOperand(0));
349
45.8k
  if (!BinInst ||
350
45.8k
      (!BinInst->isBitwiseLogicOp() &&
351
10.1k
       BinInst->getOpcode() != Instruction::Add &&
352
10.1k
       BinInst->getOpcode() != Instruction::Sub) ||
353
45.8k
      !BinInst->hasOneUse())
354
43.8k
    return nullptr;
355
356
2.00k
  Constant *C0, *C1;
357
2.00k
  if (!match(I.getOperand(1), m_Constant(C1)))
358
486
    return nullptr;
359
360
1.51k
  Instruction::BinaryOps ShiftOpcode = I.getOpcode();
361
  // Transform for add/sub only works with shl.
362
1.51k
  if ((BinInst->getOpcode() == Instruction::Add ||
363
1.51k
       BinInst->getOpcode() == Instruction::Sub) &&
364
1.51k
      ShiftOpcode != Instruction::Shl)
365
474
    return nullptr;
366
367
1.04k
  Type *Ty = I.getType();
368
369
  // Find a matching one-use shift by constant. The fold is not valid if the sum
370
  // of the shift values equals or exceeds bitwidth.
371
  // TODO: Remove the one-use check if the other logic operand (Y) is constant.
372
1.04k
  Value *X, *Y;
373
2.03k
  auto matchFirstShift = [&](Value *V) {
374
2.03k
    APInt Threshold(Ty->getScalarSizeInBits(), Ty->getScalarSizeInBits());
375
2.03k
    return match(V,
376
2.03k
                 m_OneUse(m_BinOp(ShiftOpcode, m_Value(X), m_Constant(C0)))) &&
377
2.03k
           match(ConstantExpr::getAdd(C0, C1),
378
137
                 m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold));
379
2.03k
  };
380
381
  // Logic ops and Add are commutative, so check each operand for a match. Sub
382
  // is not so we cannot reoder if we match operand(1) and need to keep the
383
  // operands in their original positions.
384
1.04k
  bool FirstShiftIsOp1 = false;
385
1.04k
  if (matchFirstShift(BinInst->getOperand(0)))
386
55
    Y = BinInst->getOperand(1);
387
990
  else if (matchFirstShift(BinInst->getOperand(1))) {
388
60
    Y = BinInst->getOperand(0);
389
60
    FirstShiftIsOp1 = BinInst->getOpcode() == Instruction::Sub;
390
60
  } else
391
930
    return nullptr;
392
393
  // shift (binop (shift X, C0), Y), C1 -> binop (shift X, C0+C1), (shift Y, C1)
394
115
  Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1);
395
115
  Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC);
396
115
  Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, C1);
397
115
  Value *Op1 = FirstShiftIsOp1 ? NewShift2 : NewShift1;
398
115
  Value *Op2 = FirstShiftIsOp1 ? NewShift1 : NewShift2;
399
115
  return BinaryOperator::Create(BinInst->getOpcode(), Op1, Op2);
400
1.04k
}
401
402
47.8k
Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
403
47.8k
  if (Instruction *Phi = foldBinopWithPhiOperands(I))
404
0
    return Phi;
405
406
47.8k
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
407
47.8k
  assert(Op0->getType() == Op1->getType());
408
0
  Type *Ty = I.getType();
409
410
  // If the shift amount is a one-use `sext`, we can demote it to `zext`.
411
47.8k
  Value *Y;
412
47.8k
  if (match(Op1, m_OneUse(m_SExt(m_Value(Y))))) {
413
15
    Value *NewExt = Builder.CreateZExt(Y, Ty, Op1->getName());
414
15
    return BinaryOperator::Create(I.getOpcode(), Op0, NewExt);
415
15
  }
416
417
  // See if we can fold away this shift.
418
47.8k
  if (SimplifyDemandedInstructionBits(I))
419
827
    return &I;
420
421
  // Try to fold constant and into select arguments.
422
47.0k
  if (isa<Constant>(Op0))
423
5.35k
    if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
424
79
      if (Instruction *R = FoldOpIntoSelect(I, SI))
425
32
        return R;
426
427
46.9k
  if (Constant *CUI = dyn_cast<Constant>(Op1))
428
32.0k
    if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
429
973
      return Res;
430
431
46.0k
  if (auto *NewShift = cast_or_null<Instruction>(
432
46.0k
          reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ)))
433
122
    return NewShift;
434
435
  // Pre-shift a constant shifted by a variable amount with constant offset:
436
  // C shift (A add nuw C1) --> (C shift C1) shift A
437
45.8k
  Value *A;
438
45.8k
  Constant *C, *C1;
439
45.8k
  if (match(Op0, m_Constant(C)) &&
440
45.8k
      match(Op1, m_NUWAdd(m_Value(A), m_Constant(C1)))) {
441
25
    Value *NewC = Builder.CreateBinOp(I.getOpcode(), C, C1);
442
25
    return BinaryOperator::Create(I.getOpcode(), NewC, A);
443
25
  }
444
445
45.8k
  unsigned BitWidth = Ty->getScalarSizeInBits();
446
447
45.8k
  const APInt *AC, *AddC;
448
  // Try to pre-shift a constant shifted by a variable amount added with a
449
  // negative number:
450
  // C << (X - AddC) --> (C >> AddC) << X
451
  // and
452
  // C >> (X - AddC) --> (C << AddC) >> X
453
45.8k
  if (match(Op0, m_APInt(AC)) && match(Op1, m_Add(m_Value(A), m_APInt(AddC))) &&
454
45.8k
      AddC->isNegative() && (-*AddC).ult(BitWidth)) {
455
157
    assert(!AC->isZero() && "Expected simplify of shifted zero");
456
0
    unsigned PosOffset = (-*AddC).getZExtValue();
457
458
157
    auto isSuitableForPreShift = [PosOffset, &I, AC]() {
459
157
      switch (I.getOpcode()) {
460
0
      default:
461
0
        return false;
462
139
      case Instruction::Shl:
463
139
        return (I.hasNoSignedWrap() || I.hasNoUnsignedWrap()) &&
464
139
               AC->eq(AC->lshr(PosOffset).shl(PosOffset));
465
8
      case Instruction::LShr:
466
8
        return I.isExact() && AC->eq(AC->shl(PosOffset).lshr(PosOffset));
467
10
      case Instruction::AShr:
468
10
        return I.isExact() && AC->eq(AC->shl(PosOffset).ashr(PosOffset));
469
157
      }
470
157
    };
471
157
    if (isSuitableForPreShift()) {
472
0
      Constant *NewC = ConstantInt::get(Ty, I.getOpcode() == Instruction::Shl
473
0
                                                ? AC->lshr(PosOffset)
474
0
                                                : AC->shl(PosOffset));
475
0
      BinaryOperator *NewShiftOp =
476
0
          BinaryOperator::Create(I.getOpcode(), NewC, A);
477
0
      if (I.getOpcode() == Instruction::Shl) {
478
0
        NewShiftOp->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
479
0
      } else {
480
0
        NewShiftOp->setIsExact();
481
0
      }
482
0
      return NewShiftOp;
483
0
    }
484
157
  }
485
486
  // X shift (A srem C) -> X shift (A and (C - 1)) iff C is a power of 2.
487
  // Because shifts by negative values (which could occur if A were negative)
488
  // are undefined.
489
45.8k
  if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Constant(C))) &&
490
45.8k
      match(C, m_Power2())) {
491
    // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't
492
    // demand the sign bit (and many others) here??
493
11
    Constant *Mask = ConstantExpr::getSub(C, ConstantInt::get(Ty, 1));
494
11
    Value *Rem = Builder.CreateAnd(A, Mask, Op1->getName());
495
11
    return replaceOperand(I, 1, Rem);
496
11
  }
497
498
45.8k
  if (Instruction *Logic = foldShiftOfShiftedBinOp(I, Builder))
499
115
    return Logic;
500
501
45.7k
  if (match(Op1, m_Or(m_Value(), m_SpecificInt(BitWidth - 1))))
502
0
    return replaceOperand(I, 1, ConstantInt::get(Ty, BitWidth - 1));
503
504
45.7k
  return nullptr;
505
45.7k
}
506
507
/// Return true if we can simplify two logical (either left or right) shifts
508
/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
509
static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
510
                                    Instruction *InnerShift,
511
803
                                    InstCombinerImpl &IC, Instruction *CxtI) {
512
803
  assert(InnerShift->isLogicalShift() && "Unexpected instruction type");
513
514
  // We need constant scalar or constant splat shifts.
515
0
  const APInt *InnerShiftConst;
516
803
  if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
517
108
    return false;
518
519
  // Two logical shifts in the same direction:
520
  // shl (shl X, C1), C2 -->  shl X, C1 + C2
521
  // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
522
695
  bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
523
695
  if (IsInnerShl == IsOuterShl)
524
169
    return true;
525
526
  // Equal shift amounts in opposite directions become bitwise 'and':
527
  // lshr (shl X, C), C --> and X, C'
528
  // shl (lshr X, C), C --> and X, C'
529
526
  if (*InnerShiftConst == OuterShAmt)
530
200
    return true;
531
532
  // If the 2nd shift is bigger than the 1st, we can fold:
533
  // lshr (shl X, C1), C2 -->  and (shl X, C1 - C2), C3
534
  // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
535
  // but it isn't profitable unless we know the and'd out bits are already zero.
536
  // Also, check that the inner shift is valid (less than the type width) or
537
  // we'll crash trying to produce the bit mask for the 'and'.
538
326
  unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
539
326
  if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) {
540
293
    unsigned InnerShAmt = InnerShiftConst->getZExtValue();
541
293
    unsigned MaskShift =
542
293
        IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
543
293
    APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
544
293
    if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI))
545
67
      return true;
546
293
  }
547
548
259
  return false;
549
326
}
550
551
/// See if we can compute the specified value, but shifted logically to the left
552
/// or right by some number of bits. This should return true if the expression
553
/// can be computed for the same cost as the current expression tree. This is
554
/// used to eliminate extraneous shifting from things like:
555
///      %C = shl i128 %A, 64
556
///      %D = shl i128 %B, 96
557
///      %E = or i128 %C, %D
558
///      %F = lshr i128 %E, 64
559
/// where the client will ask if E can be computed shifted right by 64-bits. If
560
/// this succeeds, getShiftedValue() will be called to produce the value.
561
static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
562
29.1k
                               InstCombinerImpl &IC, Instruction *CxtI) {
563
  // We can always evaluate immediate constants.
564
29.1k
  if (match(V, m_ImmConstant()))
565
164
    return true;
566
567
28.9k
  Instruction *I = dyn_cast<Instruction>(V);
568
28.9k
  if (!I) return false;
569
570
  // We can't mutate something that has multiple uses: doing so would
571
  // require duplicating the instruction in general, which isn't profitable.
572
12.0k
  if (!I->hasOneUse()) return false;
573
574
5.45k
  switch (I->getOpcode()) {
575
2.60k
  default: return false;
576
1.03k
  case Instruction::And:
577
1.61k
  case Instruction::Or:
578
1.73k
  case Instruction::Xor:
579
    // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
580
1.73k
    return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
581
1.73k
           canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
582
583
299
  case Instruction::Shl:
584
803
  case Instruction::LShr:
585
803
    return canEvaluateShiftedShift(NumBits, IsLeftShift, I, IC, CxtI);
586
587
90
  case Instruction::Select: {
588
90
    SelectInst *SI = cast<SelectInst>(I);
589
90
    Value *TrueVal = SI->getTrueValue();
590
90
    Value *FalseVal = SI->getFalseValue();
591
90
    return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
592
90
           canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
593
299
  }
594
14
  case Instruction::PHI: {
595
    // We can change a phi if we can change all operands.  Note that we never
596
    // get into trouble with cyclic PHIs here because we only consider
597
    // instructions with a single use.
598
14
    PHINode *PN = cast<PHINode>(I);
599
14
    for (Value *IncValue : PN->incoming_values())
600
16
      if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
601
14
        return false;
602
0
    return true;
603
14
  }
604
217
  case Instruction::Mul: {
605
217
    const APInt *MulConst;
606
    // We can fold (shr (mul X, -(1 << C)), C) -> (and (neg X), C`)
607
217
    return !IsLeftShift && match(I->getOperand(1), m_APInt(MulConst)) &&
608
217
           MulConst->isNegatedPowerOf2() && MulConst->countr_zero() == NumBits;
609
14
  }
610
5.45k
  }
611
5.45k
}
612
613
/// Fold OuterShift (InnerShift X, C1), C2.
614
/// See canEvaluateShiftedShift() for the constraints on these instructions.
615
static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
616
                               bool IsOuterShl,
617
275
                               InstCombiner::BuilderTy &Builder) {
618
275
  bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
619
275
  Type *ShType = InnerShift->getType();
620
275
  unsigned TypeWidth = ShType->getScalarSizeInBits();
621
622
  // We only accept shifts-by-a-constant in canEvaluateShifted().
623
275
  const APInt *C1;
624
275
  match(InnerShift->getOperand(1), m_APInt(C1));
625
275
  unsigned InnerShAmt = C1->getZExtValue();
626
627
  // Change the shift amount and clear the appropriate IR flags.
628
275
  auto NewInnerShift = [&](unsigned ShAmt) {
629
79
    InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt));
630
79
    if (IsInnerShl) {
631
32
      InnerShift->setHasNoUnsignedWrap(false);
632
32
      InnerShift->setHasNoSignedWrap(false);
633
47
    } else {
634
47
      InnerShift->setIsExact(false);
635
47
    }
636
79
    return InnerShift;
637
79
  };
638
639
  // Two logical shifts in the same direction:
640
  // shl (shl X, C1), C2 -->  shl X, C1 + C2
641
  // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
642
275
  if (IsInnerShl == IsOuterShl) {
643
    // If this is an oversized composite shift, then unsigned shifts get 0.
644
71
    if (InnerShAmt + OuterShAmt >= TypeWidth)
645
0
      return Constant::getNullValue(ShType);
646
647
71
    return NewInnerShift(InnerShAmt + OuterShAmt);
648
71
  }
649
650
  // Equal shift amounts in opposite directions become bitwise 'and':
651
  // lshr (shl X, C), C --> and X, C'
652
  // shl (lshr X, C), C --> and X, C'
653
204
  if (InnerShAmt == OuterShAmt) {
654
196
    APInt Mask = IsInnerShl
655
196
                     ? APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt)
656
196
                     : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt);
657
196
    Value *And = Builder.CreateAnd(InnerShift->getOperand(0),
658
196
                                   ConstantInt::get(ShType, Mask));
659
196
    if (auto *AndI = dyn_cast<Instruction>(And)) {
660
196
      AndI->moveBefore(InnerShift);
661
196
      AndI->takeName(InnerShift);
662
196
    }
663
196
    return And;
664
196
  }
665
666
8
  assert(InnerShAmt > OuterShAmt &&
667
8
         "Unexpected opposite direction logical shift pair");
668
669
  // In general, we would need an 'and' for this transform, but
670
  // canEvaluateShiftedShift() guarantees that the masked-off bits are not used.
671
  // lshr (shl X, C1), C2 -->  shl X, C1 - C2
672
  // shl (lshr X, C1), C2 --> lshr X, C1 - C2
673
0
  return NewInnerShift(InnerShAmt - OuterShAmt);
674
204
}
675
676
/// When canEvaluateShifted() returns true for an expression, this function
677
/// inserts the new computation that produces the shifted value.
678
static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
679
440
                              InstCombinerImpl &IC, const DataLayout &DL) {
680
  // We can always evaluate constants shifted.
681
440
  if (Constant *C = dyn_cast<Constant>(V)) {
682
83
    if (isLeftShift)
683
65
      return IC.Builder.CreateShl(C, NumBits);
684
18
    else
685
18
      return IC.Builder.CreateLShr(C, NumBits);
686
83
  }
687
688
357
  Instruction *I = cast<Instruction>(V);
689
357
  IC.addToWorklist(I);
690
691
357
  switch (I->getOpcode()) {
692
0
  default: llvm_unreachable("Inconsistency with CanEvaluateShifted");
693
29
  case Instruction::And:
694
50
  case Instruction::Or:
695
52
  case Instruction::Xor:
696
    // Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
697
52
    I->setOperand(
698
52
        0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
699
52
    I->setOperand(
700
52
        1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
701
52
    return I;
702
703
117
  case Instruction::Shl:
704
275
  case Instruction::LShr:
705
275
    return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift,
706
275
                            IC.Builder);
707
708
17
  case Instruction::Select:
709
17
    I->setOperand(
710
17
        1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
711
17
    I->setOperand(
712
17
        2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL));
713
17
    return I;
714
0
  case Instruction::PHI: {
715
    // We can change a phi if we can change all operands.  Note that we never
716
    // get into trouble with cyclic PHIs here because we only consider
717
    // instructions with a single use.
718
0
    PHINode *PN = cast<PHINode>(I);
719
0
    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
720
0
      PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits,
721
0
                                              isLeftShift, IC, DL));
722
0
    return PN;
723
117
  }
724
13
  case Instruction::Mul: {
725
13
    assert(!isLeftShift && "Unexpected shift direction!");
726
0
    auto *Neg = BinaryOperator::CreateNeg(I->getOperand(0));
727
13
    IC.InsertNewInstWith(Neg, I->getIterator());
728
13
    unsigned TypeWidth = I->getType()->getScalarSizeInBits();
729
13
    APInt Mask = APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits);
730
13
    auto *And = BinaryOperator::CreateAnd(Neg,
731
13
                                          ConstantInt::get(I->getType(), Mask));
732
13
    And->takeName(I);
733
13
    return IC.InsertNewInstWith(And, I->getIterator());
734
117
  }
735
357
  }
736
357
}
737
738
// If this is a bitwise operator or add with a constant RHS we might be able
739
// to pull it through a shift.
740
static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift,
741
1.61k
                                         BinaryOperator *BO) {
742
1.61k
  switch (BO->getOpcode()) {
743
916
  default:
744
916
    return false; // Do not perform transform!
745
97
  case Instruction::Add:
746
97
    return Shift.getOpcode() == Instruction::Shl;
747
27
  case Instruction::Or:
748
535
  case Instruction::And:
749
535
    return true;
750
65
  case Instruction::Xor:
751
    // Do not change a 'not' of logical shift because that would create a normal
752
    // 'xor'. The 'not' is likely better for analysis, SCEV, and codegen.
753
65
    return !(Shift.isLogicalShift() && match(BO, m_Not(m_Value())));
754
1.61k
  }
755
1.61k
}
756
757
Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1,
758
32.0k
                                                   BinaryOperator &I) {
759
  // (C2 << X) << C1 --> (C2 << C1) << X
760
  // (C2 >> X) >> C1 --> (C2 >> C1) >> X
761
32.0k
  Constant *C2;
762
32.0k
  Value *X;
763
32.0k
  if (match(Op0, m_BinOp(I.getOpcode(), m_ImmConstant(C2), m_Value(X))))
764
28
    return BinaryOperator::Create(
765
28
        I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X);
766
767
32.0k
  bool IsLeftShift = I.getOpcode() == Instruction::Shl;
768
32.0k
  Type *Ty = I.getType();
769
32.0k
  unsigned TypeBits = Ty->getScalarSizeInBits();
770
771
  // (X / +DivC) >> (Width - 1) --> ext (X <= -DivC)
772
  // (X / -DivC) >> (Width - 1) --> ext (X >= +DivC)
773
32.0k
  const APInt *DivC;
774
32.0k
  if (!IsLeftShift && match(C1, m_SpecificIntAllowUndef(TypeBits - 1)) &&
775
32.0k
      match(Op0, m_SDiv(m_Value(X), m_APInt(DivC))) && !DivC->isZero() &&
776
32.0k
      !DivC->isMinSignedValue()) {
777
0
    Constant *NegDivC = ConstantInt::get(Ty, -(*DivC));
778
0
    ICmpInst::Predicate Pred =
779
0
        DivC->isNegative() ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_SLE;
780
0
    Value *Cmp = Builder.CreateICmp(Pred, X, NegDivC);
781
0
    auto ExtOpcode = (I.getOpcode() == Instruction::AShr) ? Instruction::SExt
782
0
                                                          : Instruction::ZExt;
783
0
    return CastInst::Create(ExtOpcode, Cmp, Ty);
784
0
  }
785
786
32.0k
  const APInt *Op1C;
787
32.0k
  if (!match(C1, m_APInt(Op1C)))
788
2.03k
    return nullptr;
789
790
30.0k
  assert(!Op1C->uge(TypeBits) &&
791
30.0k
         "Shift over the type width should have been removed already");
792
793
  // See if we can propagate this shift into the input, this covers the trivial
794
  // cast of lshr(shl(x,c1),c2) as well as other more complex cases.
795
30.0k
  if (I.getOpcode() != Instruction::AShr &&
796
30.0k
      canEvaluateShifted(Op0, Op1C->getZExtValue(), IsLeftShift, *this, &I)) {
797
302
    LLVM_DEBUG(
798
302
        dbgs() << "ICE: GetShiftedValue propagating shift through expression"
799
302
                  " to eliminate shift:\n  IN: "
800
302
               << *Op0 << "\n  SH: " << I << "\n");
801
802
302
    return replaceInstUsesWith(
803
302
        I, getShiftedValue(Op0, Op1C->getZExtValue(), IsLeftShift, *this, DL));
804
302
  }
805
806
29.7k
  if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I))
807
33
    return FoldedShift;
808
809
29.6k
  if (!Op0->hasOneUse())
810
15.8k
    return nullptr;
811
812
13.8k
  if (auto *Op0BO = dyn_cast<BinaryOperator>(Op0)) {
813
    // If the operand is a bitwise operator with a constant RHS, and the
814
    // shift is the only use, we can pull it out of the shift.
815
3.13k
    const APInt *Op0C;
816
3.13k
    if (match(Op0BO->getOperand(1), m_APInt(Op0C))) {
817
1.56k
      if (canShiftBinOpWithConstantRHS(I, Op0BO)) {
818
566
        Value *NewRHS =
819
566
            Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(1), C1);
820
821
566
        Value *NewShift =
822
566
            Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), C1);
823
566
        NewShift->takeName(Op0BO);
824
825
566
        return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, NewRHS);
826
566
      }
827
1.56k
    }
828
3.13k
  }
829
830
  // If we have a select that conditionally executes some binary operator,
831
  // see if we can pull it the select and operator through the shift.
832
  //
833
  // For example, turning:
834
  //   shl (select C, (add X, C1), X), C2
835
  // Into:
836
  //   Y = shl X, C2
837
  //   select C, (add Y, C1 << C2), Y
838
13.2k
  Value *Cond;
839
13.2k
  BinaryOperator *TBO;
840
13.2k
  Value *FalseVal;
841
13.2k
  if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)),
842
13.2k
                          m_Value(FalseVal)))) {
843
58
    const APInt *C;
844
58
    if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal &&
845
58
        match(TBO->getOperand(1), m_APInt(C)) &&
846
58
        canShiftBinOpWithConstantRHS(I, TBO)) {
847
43
      Value *NewRHS =
848
43
          Builder.CreateBinOp(I.getOpcode(), TBO->getOperand(1), C1);
849
850
43
      Value *NewShift = Builder.CreateBinOp(I.getOpcode(), FalseVal, C1);
851
43
      Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift, NewRHS);
852
43
      return SelectInst::Create(Cond, NewOp, NewShift);
853
43
    }
854
58
  }
855
856
13.2k
  BinaryOperator *FBO;
857
13.2k
  Value *TrueVal;
858
13.2k
  if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal),
859
13.2k
                          m_OneUse(m_BinOp(FBO))))) {
860
4
    const APInt *C;
861
4
    if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal &&
862
4
        match(FBO->getOperand(1), m_APInt(C)) &&
863
4
        canShiftBinOpWithConstantRHS(I, FBO)) {
864
1
      Value *NewRHS =
865
1
          Builder.CreateBinOp(I.getOpcode(), FBO->getOperand(1), C1);
866
867
1
      Value *NewShift = Builder.CreateBinOp(I.getOpcode(), TrueVal, C1);
868
1
      Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift, NewRHS);
869
1
      return SelectInst::Create(Cond, NewShift, NewOp);
870
1
    }
871
4
  }
872
873
13.2k
  return nullptr;
874
13.2k
}
875
876
// Tries to perform
877
//    (lshr (add (zext X), (zext Y)), K)
878
//      -> (icmp ult (add X, Y), X)
879
//    where
880
//      - The add's operands are zexts from a K-bits integer to a bigger type.
881
//      - The add is only used by the shr, or by iK (or narrower) truncates.
882
//      - The lshr type has more than 2 bits (other types are boolean math).
883
//      - K > 1
884
//    note that
885
//      - The resulting add cannot have nuw/nsw, else on overflow we get a
886
//        poison value and the transform isn't legal anymore.
887
20.2k
Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
888
20.2k
  assert(I.getOpcode() == Instruction::LShr);
889
890
0
  Value *Add = I.getOperand(0);
891
20.2k
  Value *ShiftAmt = I.getOperand(1);
892
20.2k
  Type *Ty = I.getType();
893
894
20.2k
  if (Ty->getScalarSizeInBits() < 3)
895
0
    return nullptr;
896
897
20.2k
  const APInt *ShAmtAPInt = nullptr;
898
20.2k
  Value *X = nullptr, *Y = nullptr;
899
20.2k
  if (!match(ShiftAmt, m_APInt(ShAmtAPInt)) ||
900
20.2k
      !match(Add,
901
14.3k
             m_Add(m_OneUse(m_ZExt(m_Value(X))), m_OneUse(m_ZExt(m_Value(Y))))))
902
20.2k
    return nullptr;
903
904
10
  const unsigned ShAmt = ShAmtAPInt->getZExtValue();
905
10
  if (ShAmt == 1)
906
10
    return nullptr;
907
908
  // X/Y are zexts from `ShAmt`-sized ints.
909
0
  if (X->getType()->getScalarSizeInBits() != ShAmt ||
910
0
      Y->getType()->getScalarSizeInBits() != ShAmt)
911
0
    return nullptr;
912
913
  // Make sure that `Add` is only used by `I` and `ShAmt`-truncates.
914
0
  if (!Add->hasOneUse()) {
915
0
    for (User *U : Add->users()) {
916
0
      if (U == &I)
917
0
        continue;
918
919
0
      TruncInst *Trunc = dyn_cast<TruncInst>(U);
920
0
      if (!Trunc || Trunc->getType()->getScalarSizeInBits() > ShAmt)
921
0
        return nullptr;
922
0
    }
923
0
  }
924
925
  // Insert at Add so that the newly created `NarrowAdd` will dominate it's
926
  // users (i.e. `Add`'s users).
927
0
  Instruction *AddInst = cast<Instruction>(Add);
928
0
  Builder.SetInsertPoint(AddInst);
929
930
0
  Value *NarrowAdd = Builder.CreateAdd(X, Y, "add.narrowed");
931
0
  Value *Overflow =
932
0
      Builder.CreateICmpULT(NarrowAdd, X, "add.narrowed.overflow");
933
934
  // Replace the uses of the original add with a zext of the
935
  // NarrowAdd's result. Note that all users at this stage are known to
936
  // be ShAmt-sized truncs, or the lshr itself.
937
0
  if (!Add->hasOneUse()) {
938
0
    replaceInstUsesWith(*AddInst, Builder.CreateZExt(NarrowAdd, Ty));
939
0
    eraseInstFromFunction(*AddInst);
940
0
  }
941
942
  // Replace the LShr with a zext of the overflow check.
943
0
  return new ZExtInst(Overflow, Ty);
944
0
}
945
946
// Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits.
947
45.2k
static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
948
45.2k
  assert(I.isShift() && "Expected a shift as input");
949
  // We already have all the flags.
950
45.2k
  if (I.getOpcode() == Instruction::Shl) {
951
19.1k
    if (I.hasNoUnsignedWrap() && I.hasNoSignedWrap())
952
2.29k
      return false;
953
26.0k
  } else {
954
26.0k
    if (I.isExact())
955
1.81k
      return false;
956
957
    // shr (shl X, Y), Y
958
24.2k
    if (match(I.getOperand(0), m_Shl(m_Value(), m_Specific(I.getOperand(1))))) {
959
242
      I.setIsExact();
960
242
      return true;
961
242
    }
962
24.2k
  }
963
964
  // Compute what we know about shift count.
965
40.8k
  KnownBits KnownCnt = computeKnownBits(I.getOperand(1), /* Depth */ 0, Q);
966
40.8k
  unsigned BitWidth = KnownCnt.getBitWidth();
967
  // Since shift produces a poison value if RHS is equal to or larger than the
968
  // bit width, we can safely assume that RHS is less than the bit width.
969
40.8k
  uint64_t MaxCnt = KnownCnt.getMaxValue().getLimitedValue(BitWidth - 1);
970
971
40.8k
  KnownBits KnownAmt = computeKnownBits(I.getOperand(0), /* Depth */ 0, Q);
972
40.8k
  bool Changed = false;
973
974
40.8k
  if (I.getOpcode() == Instruction::Shl) {
975
    // If we have as many leading zeros than maximum shift cnt we have nuw.
976
16.8k
    if (!I.hasNoUnsignedWrap() && MaxCnt <= KnownAmt.countMinLeadingZeros()) {
977
1.48k
      I.setHasNoUnsignedWrap();
978
1.48k
      Changed = true;
979
1.48k
    }
980
    // If we have more sign bits than maximum shift cnt we have nsw.
981
16.8k
    if (!I.hasNoSignedWrap()) {
982
14.8k
      if (MaxCnt < KnownAmt.countMinSignBits() ||
983
14.8k
          MaxCnt < ComputeNumSignBits(I.getOperand(0), Q.DL, /*Depth*/ 0, Q.AC,
984
13.7k
                                      Q.CxtI, Q.DT)) {
985
1.24k
        I.setHasNoSignedWrap();
986
1.24k
        Changed = true;
987
1.24k
      }
988
14.8k
    }
989
16.8k
    return Changed;
990
16.8k
  }
991
992
  // If we have at least as many trailing zeros as maximum count then we have
993
  // exact.
994
24.0k
  Changed = MaxCnt <= KnownAmt.countMinTrailingZeros();
995
24.0k
  I.setIsExact(Changed);
996
997
24.0k
  return Changed;
998
40.8k
}
999
1000
22.5k
Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
1001
22.5k
  const SimplifyQuery Q = SQ.getWithInstruction(&I);
1002
1003
22.5k
  if (Value *V = simplifyShlInst(I.getOperand(0), I.getOperand(1),
1004
22.5k
                                 I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), Q))
1005
2.02k
    return replaceInstUsesWith(I, V);
1006
1007
20.5k
  if (Instruction *X = foldVectorBinop(I))
1008
29
    return X;
1009
1010
20.5k
  if (Instruction *V = commonShiftTransforms(I))
1011
1.16k
    return V;
1012
1013
19.3k
  if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, Q, Builder))
1014
34
    return V;
1015
1016
19.3k
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1017
19.3k
  Type *Ty = I.getType();
1018
19.3k
  unsigned BitWidth = Ty->getScalarSizeInBits();
1019
1020
19.3k
  const APInt *C;
1021
19.3k
  if (match(Op1, m_APInt(C))) {
1022
11.2k
    unsigned ShAmtC = C->getZExtValue();
1023
1024
    // shl (zext X), C --> zext (shl X, C)
1025
    // This is only valid if X would have zeros shifted out.
1026
11.2k
    Value *X;
1027
11.2k
    if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
1028
700
      unsigned SrcWidth = X->getType()->getScalarSizeInBits();
1029
700
      if (ShAmtC < SrcWidth &&
1030
700
          MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmtC), 0, &I))
1031
8
        return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty);
1032
700
    }
1033
1034
    // (X >> C) << C --> X & (-1 << C)
1035
11.2k
    if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1)))) {
1036
23
      APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
1037
23
      return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
1038
23
    }
1039
1040
11.2k
    const APInt *C1;
1041
11.2k
    if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(C1)))) &&
1042
11.2k
        C1->ult(BitWidth)) {
1043
1
      unsigned ShrAmt = C1->getZExtValue();
1044
1
      if (ShrAmt < ShAmtC) {
1045
        // If C1 < C: (X >>?,exact C1) << C --> X << (C - C1)
1046
0
        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt);
1047
0
        auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
1048
0
        NewShl->setHasNoUnsignedWrap(
1049
0
            I.hasNoUnsignedWrap() ||
1050
0
            (ShrAmt &&
1051
0
             cast<Instruction>(Op0)->getOpcode() == Instruction::LShr &&
1052
0
             I.hasNoSignedWrap()));
1053
0
        NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
1054
0
        return NewShl;
1055
0
      }
1056
1
      if (ShrAmt > ShAmtC) {
1057
        // If C1 > C: (X >>?exact C1) << C --> X >>?exact (C1 - C)
1058
1
        Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmtC);
1059
1
        auto *NewShr = BinaryOperator::Create(
1060
1
            cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff);
1061
1
        NewShr->setIsExact(true);
1062
1
        return NewShr;
1063
1
      }
1064
1
    }
1065
1066
11.2k
    if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_APInt(C1)))) &&
1067
11.2k
        C1->ult(BitWidth)) {
1068
47
      unsigned ShrAmt = C1->getZExtValue();
1069
47
      if (ShrAmt < ShAmtC) {
1070
        // If C1 < C: (X >>? C1) << C --> (X << (C - C1)) & (-1 << C)
1071
22
        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt);
1072
22
        auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
1073
22
        NewShl->setHasNoUnsignedWrap(
1074
22
            I.hasNoUnsignedWrap() ||
1075
22
            (ShrAmt &&
1076
22
             cast<Instruction>(Op0)->getOpcode() == Instruction::LShr &&
1077
22
             I.hasNoSignedWrap()));
1078
22
        NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
1079
22
        Builder.Insert(NewShl);
1080
22
        APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
1081
22
        return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask));
1082
22
      }
1083
25
      if (ShrAmt > ShAmtC) {
1084
        // If C1 > C: (X >>? C1) << C --> (X >>? (C1 - C)) & (-1 << C)
1085
25
        Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmtC);
1086
25
        auto *OldShr = cast<BinaryOperator>(Op0);
1087
25
        auto *NewShr =
1088
25
            BinaryOperator::Create(OldShr->getOpcode(), X, ShiftDiff);
1089
25
        NewShr->setIsExact(OldShr->isExact());
1090
25
        Builder.Insert(NewShr);
1091
25
        APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
1092
25
        return BinaryOperator::CreateAnd(NewShr, ConstantInt::get(Ty, Mask));
1093
25
      }
1094
25
    }
1095
1096
    // Similar to above, but look through an intermediate trunc instruction.
1097
11.1k
    BinaryOperator *Shr;
1098
11.1k
    if (match(Op0, m_OneUse(m_Trunc(m_OneUse(m_BinOp(Shr))))) &&
1099
11.1k
        match(Shr, m_Shr(m_Value(X), m_APInt(C1)))) {
1100
      // The larger shift direction survives through the transform.
1101
8
      unsigned ShrAmtC = C1->getZExtValue();
1102
8
      unsigned ShDiff = ShrAmtC > ShAmtC ? ShrAmtC - ShAmtC : ShAmtC - ShrAmtC;
1103
8
      Constant *ShiftDiffC = ConstantInt::get(X->getType(), ShDiff);
1104
8
      auto ShiftOpc = ShrAmtC > ShAmtC ? Shr->getOpcode() : Instruction::Shl;
1105
1106
      // If C1 > C:
1107
      // (trunc (X >> C1)) << C --> (trunc (X >> (C1 - C))) && (-1 << C)
1108
      // If C > C1:
1109
      // (trunc (X >> C1)) << C --> (trunc (X << (C - C1))) && (-1 << C)
1110
8
      Value *NewShift = Builder.CreateBinOp(ShiftOpc, X, ShiftDiffC, "sh.diff");
1111
8
      Value *Trunc = Builder.CreateTrunc(NewShift, Ty, "tr.sh.diff");
1112
8
      APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
1113
8
      return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, Mask));
1114
8
    }
1115
1116
11.1k
    if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) {
1117
0
      unsigned AmtSum = ShAmtC + C1->getZExtValue();
1118
      // Oversized shifts are simplified to zero in InstSimplify.
1119
0
      if (AmtSum < BitWidth)
1120
        // (X << C1) << C2 --> X << (C1 + C2)
1121
0
        return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum));
1122
0
    }
1123
1124
    // If we have an opposite shift by the same amount, we may be able to
1125
    // reorder binops and shifts to eliminate math/logic.
1126
11.1k
    auto isSuitableBinOpcode = [](Instruction::BinaryOps BinOpcode) {
1127
881
      switch (BinOpcode) {
1128
311
      default:
1129
311
        return false;
1130
50
      case Instruction::Add:
1131
73
      case Instruction::And:
1132
382
      case Instruction::Or:
1133
430
      case Instruction::Xor:
1134
570
      case Instruction::Sub:
1135
        // NOTE: Sub is not commutable and the tranforms below may not be valid
1136
        //       when the shift-right is operand 1 (RHS) of the sub.
1137
570
        return true;
1138
881
      }
1139
881
    };
1140
11.1k
    BinaryOperator *Op0BO;
1141
11.1k
    if (match(Op0, m_OneUse(m_BinOp(Op0BO))) &&
1142
11.1k
        isSuitableBinOpcode(Op0BO->getOpcode())) {
1143
      // Commute so shift-right is on LHS of the binop.
1144
      // (Y bop (X >> C)) << C         ->  ((X >> C) bop Y) << C
1145
      // (Y bop ((X >> C) & CC)) << C  ->  (((X >> C) & CC) bop Y) << C
1146
570
      Value *Shr = Op0BO->getOperand(0);
1147
570
      Value *Y = Op0BO->getOperand(1);
1148
570
      Value *X;
1149
570
      const APInt *CC;
1150
570
      if (Op0BO->isCommutative() && Y->hasOneUse() &&
1151
570
          (match(Y, m_Shr(m_Value(), m_Specific(Op1))) ||
1152
357
           match(Y, m_And(m_OneUse(m_Shr(m_Value(), m_Specific(Op1))),
1153
351
                          m_APInt(CC)))))
1154
6
        std::swap(Shr, Y);
1155
1156
      // ((X >> C) bop Y) << C  ->  (X bop (Y << C)) & (~0 << C)
1157
570
      if (match(Shr, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) {
1158
        // Y << C
1159
13
        Value *YS = Builder.CreateShl(Y, Op1, Op0BO->getName());
1160
        // (X bop (Y << C))
1161
13
        Value *B =
1162
13
            Builder.CreateBinOp(Op0BO->getOpcode(), X, YS, Shr->getName());
1163
13
        unsigned Op1Val = C->getLimitedValue(BitWidth);
1164
13
        APInt Bits = APInt::getHighBitsSet(BitWidth, BitWidth - Op1Val);
1165
13
        Constant *Mask = ConstantInt::get(Ty, Bits);
1166
13
        return BinaryOperator::CreateAnd(B, Mask);
1167
13
      }
1168
1169
      // (((X >> C) & CC) bop Y) << C  ->  (X & (CC << C)) bop (Y << C)
1170
557
      if (match(Shr,
1171
557
                m_OneUse(m_And(m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))),
1172
557
                               m_APInt(CC))))) {
1173
        // Y << C
1174
10
        Value *YS = Builder.CreateShl(Y, Op1, Op0BO->getName());
1175
        // X & (CC << C)
1176
10
        Value *M = Builder.CreateAnd(X, ConstantInt::get(Ty, CC->shl(*C)),
1177
10
                                     X->getName() + ".mask");
1178
10
        return BinaryOperator::Create(Op0BO->getOpcode(), M, YS);
1179
10
      }
1180
557
    }
1181
1182
    // (C1 - X) << C --> (C1 << C) - (X << C)
1183
11.1k
    if (match(Op0, m_OneUse(m_Sub(m_APInt(C1), m_Value(X))))) {
1184
53
      Constant *NewLHS = ConstantInt::get(Ty, C1->shl(*C));
1185
53
      Value *NewShift = Builder.CreateShl(X, Op1);
1186
53
      return BinaryOperator::CreateSub(NewLHS, NewShift);
1187
53
    }
1188
11.1k
  }
1189
1190
19.1k
  if (setShiftFlags(I, Q))
1191
1.95k
    return &I;
1192
1193
  // Transform  (x >> y) << y  to  x & (-1 << y)
1194
  // Valid for any type of right-shift.
1195
17.2k
  Value *X;
1196
17.2k
  if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) {
1197
3
    Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
1198
3
    Value *Mask = Builder.CreateShl(AllOnes, Op1);
1199
3
    return BinaryOperator::CreateAnd(Mask, X);
1200
3
  }
1201
1202
17.1k
  Constant *C1;
1203
17.1k
  if (match(Op1, m_Constant(C1))) {
1204
11.0k
    Constant *C2;
1205
11.0k
    Value *X;
1206
    // (X * C2) << C1 --> X * (C2 << C1)
1207
11.0k
    if (match(Op0, m_Mul(m_Value(X), m_Constant(C2))))
1208
60
      return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));
1209
1210
    // shl (zext i1 X), C1 --> select (X, 1 << C1, 0)
1211
11.0k
    if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
1212
113
      auto *NewC = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C1);
1213
113
      return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
1214
113
    }
1215
11.0k
  }
1216
1217
17.0k
  if (match(Op0, m_One())) {
1218
    // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1
1219
1.03k
    if (match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X))))
1220
0
      return BinaryOperator::CreateLShr(
1221
0
          ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X);
1222
1223
    // Canonicalize "extract lowest set bit" using cttz to and-with-negate:
1224
    // 1 << (cttz X) --> -X & X
1225
1.03k
    if (match(Op1,
1226
1.03k
              m_OneUse(m_Intrinsic<Intrinsic::cttz>(m_Value(X), m_Value())))) {
1227
0
      Value *NegX = Builder.CreateNeg(X, "neg");
1228
0
      return BinaryOperator::CreateAnd(NegX, X);
1229
0
    }
1230
1.03k
  }
1231
1232
17.0k
  return nullptr;
1233
17.0k
}
1234
1235
23.5k
Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
1236
23.5k
  if (Value *V = simplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
1237
23.5k
                                  SQ.getWithInstruction(&I)))
1238
2.21k
    return replaceInstUsesWith(I, V);
1239
1240
21.3k
  if (Instruction *X = foldVectorBinop(I))
1241
16
    return X;
1242
1243
21.3k
  if (Instruction *R = commonShiftTransforms(I))
1244
710
    return R;
1245
1246
20.6k
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1247
20.6k
  Type *Ty = I.getType();
1248
20.6k
  Value *X;
1249
20.6k
  const APInt *C;
1250
20.6k
  unsigned BitWidth = Ty->getScalarSizeInBits();
1251
1252
  // (iN (~X) u>> (N - 1)) --> zext (X > -1)
1253
20.6k
  if (match(Op0, m_OneUse(m_Not(m_Value(X)))) &&
1254
20.6k
      match(Op1, m_SpecificIntAllowUndef(BitWidth - 1)))
1255
4
    return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty);
1256
1257
20.6k
  if (match(Op1, m_APInt(C))) {
1258
14.6k
    unsigned ShAmtC = C->getZExtValue();
1259
14.6k
    auto *II = dyn_cast<IntrinsicInst>(Op0);
1260
14.6k
    if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmtC &&
1261
14.6k
        (II->getIntrinsicID() == Intrinsic::ctlz ||
1262
49
         II->getIntrinsicID() == Intrinsic::cttz ||
1263
49
         II->getIntrinsicID() == Intrinsic::ctpop)) {
1264
      // ctlz.i32(x)>>5  --> zext(x == 0)
1265
      // cttz.i32(x)>>5  --> zext(x == 0)
1266
      // ctpop.i32(x)>>5 --> zext(x == -1)
1267
0
      bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop;
1268
0
      Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? -1 : 0);
1269
0
      Value *Cmp = Builder.CreateICmpEQ(II->getArgOperand(0), RHS);
1270
0
      return new ZExtInst(Cmp, Ty);
1271
0
    }
1272
1273
14.6k
    Value *X;
1274
14.6k
    const APInt *C1;
1275
14.6k
    if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) {
1276
223
      if (C1->ult(ShAmtC)) {
1277
19
        unsigned ShlAmtC = C1->getZExtValue();
1278
19
        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShlAmtC);
1279
19
        if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
1280
          // (X <<nuw C1) >>u C --> X >>u (C - C1)
1281
4
          auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff);
1282
4
          NewLShr->setIsExact(I.isExact());
1283
4
          return NewLShr;
1284
4
        }
1285
15
        if (Op0->hasOneUse()) {
1286
          // (X << C1) >>u C  --> (X >>u (C - C1)) & (-1 >> C)
1287
6
          Value *NewLShr = Builder.CreateLShr(X, ShiftDiff, "", I.isExact());
1288
6
          APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC));
1289
6
          return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask));
1290
6
        }
1291
204
      } else if (C1->ugt(ShAmtC)) {
1292
190
        unsigned ShlAmtC = C1->getZExtValue();
1293
190
        Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmtC - ShAmtC);
1294
190
        if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
1295
          // (X <<nuw C1) >>u C --> X <<nuw/nsw (C1 - C)
1296
10
          auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
1297
10
          NewShl->setHasNoUnsignedWrap(true);
1298
10
          NewShl->setHasNoSignedWrap(ShAmtC > 0);
1299
10
          return NewShl;
1300
10
        }
1301
180
        if (Op0->hasOneUse()) {
1302
          // (X << C1) >>u C  --> X << (C1 - C) & (-1 >> C)
1303
14
          Value *NewShl = Builder.CreateShl(X, ShiftDiff);
1304
14
          APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC));
1305
14
          return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask));
1306
14
        }
1307
180
      } else {
1308
14
        assert(*C1 == ShAmtC);
1309
        // (X << C) >>u C --> X & (-1 >>u C)
1310
0
        APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC));
1311
14
        return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
1312
14
      }
1313
223
    }
1314
1315
    // ((X << C) + Y) >>u C --> (X + (Y >>u C)) & (-1 >>u C)
1316
    // TODO: Consolidate with the more general transform that starts from shl
1317
    //       (the shifts are in the opposite order).
1318
14.6k
    Value *Y;
1319
14.6k
    if (match(Op0,
1320
14.6k
              m_OneUse(m_c_Add(m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))),
1321
14.6k
                               m_Value(Y))))) {
1322
0
      Value *NewLshr = Builder.CreateLShr(Y, Op1);
1323
0
      Value *NewAdd = Builder.CreateAdd(NewLshr, X);
1324
0
      unsigned Op1Val = C->getLimitedValue(BitWidth);
1325
0
      APInt Bits = APInt::getLowBitsSet(BitWidth, BitWidth - Op1Val);
1326
0
      Constant *Mask = ConstantInt::get(Ty, Bits);
1327
0
      return BinaryOperator::CreateAnd(NewAdd, Mask);
1328
0
    }
1329
1330
14.6k
    if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) &&
1331
14.6k
        (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) {
1332
18
      assert(ShAmtC < X->getType()->getScalarSizeInBits() &&
1333
18
             "Big shift not simplified to zero?");
1334
      // lshr (zext iM X to iN), C --> zext (lshr X, C) to iN
1335
0
      Value *NewLShr = Builder.CreateLShr(X, ShAmtC);
1336
18
      return new ZExtInst(NewLShr, Ty);
1337
18
    }
1338
1339
14.6k
    if (match(Op0, m_SExt(m_Value(X)))) {
1340
17
      unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits();
1341
      // lshr (sext i1 X to iN), C --> select (X, -1 >> C, 0)
1342
17
      if (SrcTyBitWidth == 1) {
1343
1
        auto *NewC = ConstantInt::get(
1344
1
            Ty, APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC));
1345
1
        return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
1346
1
      }
1347
1348
16
      if ((!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType())) &&
1349
16
          Op0->hasOneUse()) {
1350
        // Are we moving the sign bit to the low bit and widening with high
1351
        // zeros? lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN
1352
3
        if (ShAmtC == BitWidth - 1) {
1353
0
          Value *NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1);
1354
0
          return new ZExtInst(NewLShr, Ty);
1355
0
        }
1356
1357
        // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN
1358
3
        if (ShAmtC == BitWidth - SrcTyBitWidth) {
1359
          // The new shift amount can't be more than the narrow source type.
1360
0
          unsigned NewShAmt = std::min(ShAmtC, SrcTyBitWidth - 1);
1361
0
          Value *AShr = Builder.CreateAShr(X, NewShAmt);
1362
0
          return new ZExtInst(AShr, Ty);
1363
0
        }
1364
3
      }
1365
16
    }
1366
1367
14.6k
    if (ShAmtC == BitWidth - 1) {
1368
      // lshr i32 or(X,-X), 31 --> zext (X != 0)
1369
869
      if (match(Op0, m_OneUse(m_c_Or(m_Neg(m_Value(X)), m_Deferred(X)))))
1370
0
        return new ZExtInst(Builder.CreateIsNotNull(X), Ty);
1371
1372
      // lshr i32 (X -nsw Y), 31 --> zext (X < Y)
1373
869
      if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
1374
4
        return new ZExtInst(Builder.CreateICmpSLT(X, Y), Ty);
1375
1376
      // Check if a number is negative and odd:
1377
      // lshr i32 (srem X, 2), 31 --> and (X >> 31), X
1378
865
      if (match(Op0, m_OneUse(m_SRem(m_Value(X), m_SpecificInt(2))))) {
1379
0
        Value *Signbit = Builder.CreateLShr(X, ShAmtC);
1380
0
        return BinaryOperator::CreateAnd(Signbit, X);
1381
0
      }
1382
865
    }
1383
1384
    // (X >>u C1) >>u C --> X >>u (C1 + C)
1385
14.6k
    if (match(Op0, m_LShr(m_Value(X), m_APInt(C1)))) {
1386
      // Oversized shifts are simplified to zero in InstSimplify.
1387
0
      unsigned AmtSum = ShAmtC + C1->getZExtValue();
1388
0
      if (AmtSum < BitWidth)
1389
0
        return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));
1390
0
    }
1391
1392
14.6k
    Instruction *TruncSrc;
1393
14.6k
    if (match(Op0, m_OneUse(m_Trunc(m_Instruction(TruncSrc)))) &&
1394
14.6k
        match(TruncSrc, m_LShr(m_Value(X), m_APInt(C1)))) {
1395
0
      unsigned SrcWidth = X->getType()->getScalarSizeInBits();
1396
0
      unsigned AmtSum = ShAmtC + C1->getZExtValue();
1397
1398
      // If the combined shift fits in the source width:
1399
      // (trunc (X >>u C1)) >>u C --> and (trunc (X >>u (C1 + C)), MaskC
1400
      //
1401
      // If the first shift covers the number of bits truncated, then the
1402
      // mask instruction is eliminated (and so the use check is relaxed).
1403
0
      if (AmtSum < SrcWidth &&
1404
0
          (TruncSrc->hasOneUse() || C1->uge(SrcWidth - BitWidth))) {
1405
0
        Value *SumShift = Builder.CreateLShr(X, AmtSum, "sum.shift");
1406
0
        Value *Trunc = Builder.CreateTrunc(SumShift, Ty, I.getName());
1407
1408
        // If the first shift does not cover the number of bits truncated, then
1409
        // we require a mask to get rid of high bits in the result.
1410
0
        APInt MaskC = APInt::getAllOnes(BitWidth).lshr(ShAmtC);
1411
0
        return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, MaskC));
1412
0
      }
1413
0
    }
1414
1415
14.6k
    const APInt *MulC;
1416
14.6k
    if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC)))) {
1417
      // Look for a "splat" mul pattern - it replicates bits across each half of
1418
      // a value, so a right shift is just a mask of the low bits:
1419
      // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1
1420
      // TODO: Generalize to allow more than just half-width shifts?
1421
122
      if (BitWidth > 2 && ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() &&
1422
122
          MulC->logBase2() == ShAmtC)
1423
0
        return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2));
1424
1425
      // The one-use check is not strictly necessary, but codegen may not be
1426
      // able to invert the transform and perf may suffer with an extra mul
1427
      // instruction.
1428
122
      if (Op0->hasOneUse()) {
1429
85
        APInt NewMulC = MulC->lshr(ShAmtC);
1430
        // if c is divisible by (1 << ShAmtC):
1431
        // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC)
1432
85
        if (MulC->eq(NewMulC.shl(ShAmtC))) {
1433
0
          auto *NewMul =
1434
0
              BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC));
1435
0
          assert(ShAmtC != 0 &&
1436
0
                 "lshr X, 0 should be handled by simplifyLShrInst.");
1437
0
          NewMul->setHasNoSignedWrap(true);
1438
0
          return NewMul;
1439
0
        }
1440
85
      }
1441
122
    }
1442
1443
    // Try to narrow bswap.
1444
    // In the case where the shift amount equals the bitwidth difference, the
1445
    // shift is eliminated.
1446
14.6k
    if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::bswap>(
1447
14.6k
                       m_OneUse(m_ZExt(m_Value(X))))))) {
1448
0
      unsigned SrcWidth = X->getType()->getScalarSizeInBits();
1449
0
      unsigned WidthDiff = BitWidth - SrcWidth;
1450
0
      if (SrcWidth % 16 == 0) {
1451
0
        Value *NarrowSwap = Builder.CreateUnaryIntrinsic(Intrinsic::bswap, X);
1452
0
        if (ShAmtC >= WidthDiff) {
1453
          // (bswap (zext X)) >> C --> zext (bswap X >> C')
1454
0
          Value *NewShift = Builder.CreateLShr(NarrowSwap, ShAmtC - WidthDiff);
1455
0
          return new ZExtInst(NewShift, Ty);
1456
0
        } else {
1457
          // (bswap (zext X)) >> C --> (zext (bswap X)) << C'
1458
0
          Value *NewZExt = Builder.CreateZExt(NarrowSwap, Ty);
1459
0
          Constant *ShiftDiff = ConstantInt::get(Ty, WidthDiff - ShAmtC);
1460
0
          return BinaryOperator::CreateShl(NewZExt, ShiftDiff);
1461
0
        }
1462
0
      }
1463
0
    }
1464
1465
    // Reduce add-carry of bools to logic:
1466
    // ((zext BoolX) + (zext BoolY)) >> 1 --> zext (BoolX && BoolY)
1467
14.6k
    Value *BoolX, *BoolY;
1468
14.6k
    if (ShAmtC == 1 && match(Op0, m_Add(m_Value(X), m_Value(Y))) &&
1469
14.6k
        match(X, m_ZExt(m_Value(BoolX))) && match(Y, m_ZExt(m_Value(BoolY))) &&
1470
14.6k
        BoolX->getType()->isIntOrIntVectorTy(1) &&
1471
14.6k
        BoolY->getType()->isIntOrIntVectorTy(1) &&
1472
14.6k
        (X->hasOneUse() || Y->hasOneUse() || Op0->hasOneUse())) {
1473
0
      Value *And = Builder.CreateAnd(BoolX, BoolY);
1474
0
      return new ZExtInst(And, Ty);
1475
0
    }
1476
14.6k
  }
1477
1478
20.5k
  const SimplifyQuery Q = SQ.getWithInstruction(&I);
1479
20.5k
  if (setShiftFlags(I, Q))
1480
332
    return &I;
1481
1482
  // Transform  (x << y) >> y  to  x & (-1 >> y)
1483
20.2k
  if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) {
1484
4
    Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
1485
4
    Value *Mask = Builder.CreateLShr(AllOnes, Op1);
1486
4
    return BinaryOperator::CreateAnd(Mask, X);
1487
4
  }
1488
1489
20.2k
  if (Instruction *Overflow = foldLShrOverflowBit(I))
1490
0
    return Overflow;
1491
1492
20.2k
  return nullptr;
1493
20.2k
}
1494
1495
Instruction *
1496
InstCombinerImpl::foldVariableSignZeroExtensionOfVariableHighBitExtract(
1497
5.23k
    BinaryOperator &OldAShr) {
1498
5.23k
  assert(OldAShr.getOpcode() == Instruction::AShr &&
1499
5.23k
         "Must be called with arithmetic right-shift instruction only.");
1500
1501
  // Check that constant C is a splat of the element-wise bitwidth of V.
1502
0
  auto BitWidthSplat = [](Constant *C, Value *V) {
1503
0
    return match(
1504
0
        C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
1505
0
                              APInt(C->getType()->getScalarSizeInBits(),
1506
0
                                    V->getType()->getScalarSizeInBits())));
1507
0
  };
1508
1509
  // It should look like variable-length sign-extension on the outside:
1510
  //   (Val << (bitwidth(Val)-Nbits)) a>> (bitwidth(Val)-Nbits)
1511
5.23k
  Value *NBits;
1512
5.23k
  Instruction *MaybeTrunc;
1513
5.23k
  Constant *C1, *C2;
1514
5.23k
  if (!match(&OldAShr,
1515
5.23k
             m_AShr(m_Shl(m_Instruction(MaybeTrunc),
1516
5.23k
                          m_ZExtOrSelf(m_Sub(m_Constant(C1),
1517
5.23k
                                             m_ZExtOrSelf(m_Value(NBits))))),
1518
5.23k
                    m_ZExtOrSelf(m_Sub(m_Constant(C2),
1519
5.23k
                                       m_ZExtOrSelf(m_Deferred(NBits)))))) ||
1520
5.23k
      !BitWidthSplat(C1, &OldAShr) || !BitWidthSplat(C2, &OldAShr))
1521
5.23k
    return nullptr;
1522
1523
  // There may or may not be a truncation after outer two shifts.
1524
0
  Instruction *HighBitExtract;
1525
0
  match(MaybeTrunc, m_TruncOrSelf(m_Instruction(HighBitExtract)));
1526
0
  bool HadTrunc = MaybeTrunc != HighBitExtract;
1527
1528
  // And finally, the innermost part of the pattern must be a right-shift.
1529
0
  Value *X, *NumLowBitsToSkip;
1530
0
  if (!match(HighBitExtract, m_Shr(m_Value(X), m_Value(NumLowBitsToSkip))))
1531
0
    return nullptr;
1532
1533
  // Said right-shift must extract high NBits bits - C0 must be it's bitwidth.
1534
0
  Constant *C0;
1535
0
  if (!match(NumLowBitsToSkip,
1536
0
             m_ZExtOrSelf(
1537
0
                 m_Sub(m_Constant(C0), m_ZExtOrSelf(m_Specific(NBits))))) ||
1538
0
      !BitWidthSplat(C0, HighBitExtract))
1539
0
    return nullptr;
1540
1541
  // Since the NBits is identical for all shifts, if the outermost and
1542
  // innermost shifts are identical, then outermost shifts are redundant.
1543
  // If we had truncation, do keep it though.
1544
0
  if (HighBitExtract->getOpcode() == OldAShr.getOpcode())
1545
0
    return replaceInstUsesWith(OldAShr, MaybeTrunc);
1546
1547
  // Else, if there was a truncation, then we need to ensure that one
1548
  // instruction will go away.
1549
0
  if (HadTrunc && !match(&OldAShr, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
1550
0
    return nullptr;
1551
1552
  // Finally, bypass two innermost shifts, and perform the outermost shift on
1553
  // the operands of the innermost shift.
1554
0
  Instruction *NewAShr =
1555
0
      BinaryOperator::Create(OldAShr.getOpcode(), X, NumLowBitsToSkip);
1556
0
  NewAShr->copyIRFlags(HighBitExtract); // We can preserve 'exact'-ness.
1557
0
  if (!HadTrunc)
1558
0
    return NewAShr;
1559
1560
0
  Builder.Insert(NewAShr);
1561
0
  return TruncInst::CreateTruncOrBitCast(NewAShr, OldAShr.getType());
1562
0
}
1563
1564
8.20k
Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
1565
8.20k
  if (Value *V = simplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
1566
8.20k
                                  SQ.getWithInstruction(&I)))
1567
2.18k
    return replaceInstUsesWith(I, V);
1568
1569
6.01k
  if (Instruction *X = foldVectorBinop(I))
1570
7
    return X;
1571
1572
6.00k
  if (Instruction *R = commonShiftTransforms(I))
1573
244
    return R;
1574
1575
5.76k
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1576
5.76k
  Type *Ty = I.getType();
1577
5.76k
  unsigned BitWidth = Ty->getScalarSizeInBits();
1578
5.76k
  const APInt *ShAmtAPInt;
1579
5.76k
  if (match(Op1, m_APInt(ShAmtAPInt)) && ShAmtAPInt->ult(BitWidth)) {
1580
2.98k
    unsigned ShAmt = ShAmtAPInt->getZExtValue();
1581
1582
    // If the shift amount equals the difference in width of the destination
1583
    // and source scalar types:
1584
    // ashr (shl (zext X), C), C --> sext X
1585
2.98k
    Value *X;
1586
2.98k
    if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) &&
1587
2.98k
        ShAmt == BitWidth - X->getType()->getScalarSizeInBits())
1588
0
      return new SExtInst(X, Ty);
1589
1590
    // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However,
1591
    // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
1592
2.98k
    const APInt *ShOp1;
1593
2.98k
    if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1))) &&
1594
2.98k
        ShOp1->ult(BitWidth)) {
1595
7
      unsigned ShlAmt = ShOp1->getZExtValue();
1596
7
      if (ShlAmt < ShAmt) {
1597
        // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1)
1598
0
        Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
1599
0
        auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff);
1600
0
        NewAShr->setIsExact(I.isExact());
1601
0
        return NewAShr;
1602
0
      }
1603
7
      if (ShlAmt > ShAmt) {
1604
        // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2)
1605
7
        Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
1606
7
        auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff);
1607
7
        NewShl->setHasNoSignedWrap(true);
1608
7
        return NewShl;
1609
7
      }
1610
7
    }
1611
1612
2.97k
    if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1))) &&
1613
2.97k
        ShOp1->ult(BitWidth)) {
1614
12
      unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
1615
      // Oversized arithmetic shifts replicate the sign bit.
1616
12
      AmtSum = std::min(AmtSum, BitWidth - 1);
1617
      // (X >>s C1) >>s C2 --> X >>s (C1 + C2)
1618
12
      return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum));
1619
12
    }
1620
1621
2.96k
    if (match(Op0, m_OneUse(m_SExt(m_Value(X)))) &&
1622
2.96k
        (Ty->isVectorTy() || shouldChangeType(Ty, X->getType()))) {
1623
      // ashr (sext X), C --> sext (ashr X, C')
1624
22
      Type *SrcTy = X->getType();
1625
22
      ShAmt = std::min(ShAmt, SrcTy->getScalarSizeInBits() - 1);
1626
22
      Value *NewSh = Builder.CreateAShr(X, ConstantInt::get(SrcTy, ShAmt));
1627
22
      return new SExtInst(NewSh, Ty);
1628
22
    }
1629
1630
2.93k
    if (ShAmt == BitWidth - 1) {
1631
      // ashr i32 or(X,-X), 31 --> sext (X != 0)
1632
976
      if (match(Op0, m_OneUse(m_c_Or(m_Neg(m_Value(X)), m_Deferred(X)))))
1633
35
        return new SExtInst(Builder.CreateIsNotNull(X), Ty);
1634
1635
      // ashr i32 (X -nsw Y), 31 --> sext (X < Y)
1636
941
      Value *Y;
1637
941
      if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
1638
166
        return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);
1639
941
    }
1640
2.93k
  }
1641
1642
5.52k
  const SimplifyQuery Q = SQ.getWithInstruction(&I);
1643
5.52k
  if (setShiftFlags(I, Q))
1644
267
    return &I;
1645
1646
  // Prefer `-(x & 1)` over `(x << (bitwidth(x)-1)) a>> (bitwidth(x)-1)`
1647
  // as the pattern to splat the lowest bit.
1648
  // FIXME: iff X is already masked, we don't need the one-use check.
1649
5.25k
  Value *X;
1650
5.25k
  if (match(Op1, m_SpecificIntAllowUndef(BitWidth - 1)) &&
1651
5.25k
      match(Op0, m_OneUse(m_Shl(m_Value(X),
1652
828
                                m_SpecificIntAllowUndef(BitWidth - 1))))) {
1653
24
    Constant *Mask = ConstantInt::get(Ty, 1);
1654
    // Retain the knowledge about the ignored lanes.
1655
24
    Mask = Constant::mergeUndefsWith(
1656
24
        Constant::mergeUndefsWith(Mask, cast<Constant>(Op1)),
1657
24
        cast<Constant>(cast<Instruction>(Op0)->getOperand(1)));
1658
24
    X = Builder.CreateAnd(X, Mask);
1659
24
    return BinaryOperator::CreateNeg(X);
1660
24
  }
1661
1662
5.23k
  if (Instruction *R = foldVariableSignZeroExtensionOfVariableHighBitExtract(I))
1663
0
    return R;
1664
1665
  // See if we can turn a signed shr into an unsigned shr.
1666
5.23k
  if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I)) {
1667
365
    Instruction *Lshr = BinaryOperator::CreateLShr(Op0, Op1);
1668
365
    Lshr->setIsExact(I.isExact());
1669
365
    return Lshr;
1670
365
  }
1671
1672
  // ashr (xor %x, -1), %y  -->  xor (ashr %x, %y), -1
1673
4.86k
  if (match(Op0, m_OneUse(m_Not(m_Value(X))))) {
1674
    // Note that we must drop 'exact'-ness of the shift!
1675
    // Note that we can't keep undef's in -1 vector constant!
1676
8
    auto *NewAShr = Builder.CreateAShr(X, Op1, Op0->getName() + ".not");
1677
8
    return BinaryOperator::CreateNot(NewAShr);
1678
8
  }
1679
1680
4.85k
  return nullptr;
1681
4.86k
}