Coverage Report

Created: 2024-01-17 10:31

/src/llvm-project/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
Line
Count
Source
1
//=== AArch64PostLegalizerLowering.cpp --------------------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
///
9
/// \file
10
/// Post-legalization lowering for instructions.
11
///
12
/// This is used to offload pattern matching from the selector.
13
///
14
/// For example, this combiner will notice that a G_SHUFFLE_VECTOR is actually
15
/// a G_ZIP, G_UZP, etc.
16
///
17
/// General optimization combines should be handled by either the
18
/// AArch64PostLegalizerCombiner or the AArch64PreLegalizerCombiner.
19
///
20
//===----------------------------------------------------------------------===//
21
22
#include "AArch64GlobalISelUtils.h"
23
#include "AArch64Subtarget.h"
24
#include "AArch64TargetMachine.h"
25
#include "GISel/AArch64LegalizerInfo.h"
26
#include "MCTargetDesc/AArch64MCTargetDesc.h"
27
#include "TargetInfo/AArch64TargetInfo.h"
28
#include "Utils/AArch64BaseInfo.h"
29
#include "llvm/CodeGen/GlobalISel/Combiner.h"
30
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
31
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
32
#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h"
33
#include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
34
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
35
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
36
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
37
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
38
#include "llvm/CodeGen/GlobalISel/Utils.h"
39
#include "llvm/CodeGen/MachineFunctionPass.h"
40
#include "llvm/CodeGen/MachineInstrBuilder.h"
41
#include "llvm/CodeGen/MachineRegisterInfo.h"
42
#include "llvm/CodeGen/TargetOpcodes.h"
43
#include "llvm/CodeGen/TargetPassConfig.h"
44
#include "llvm/IR/InstrTypes.h"
45
#include "llvm/InitializePasses.h"
46
#include "llvm/Support/Debug.h"
47
#include "llvm/Support/ErrorHandling.h"
48
#include <optional>
49
50
#define GET_GICOMBINER_DEPS
51
#include "AArch64GenPostLegalizeGILowering.inc"
52
#undef GET_GICOMBINER_DEPS
53
54
#define DEBUG_TYPE "aarch64-postlegalizer-lowering"
55
56
using namespace llvm;
57
using namespace MIPatternMatch;
58
using namespace AArch64GISelUtils;
59
60
namespace {
61
62
#define GET_GICOMBINER_TYPES
63
#include "AArch64GenPostLegalizeGILowering.inc"
64
#undef GET_GICOMBINER_TYPES
65
66
/// Represents a pseudo instruction which replaces a G_SHUFFLE_VECTOR.
67
///
68
/// Used for matching target-supported shuffles before codegen.
69
struct ShuffleVectorPseudo {
70
  unsigned Opc;                 ///< Opcode for the instruction. (E.g. G_ZIP1)
71
  Register Dst;                 ///< Destination register.
72
  SmallVector<SrcOp, 2> SrcOps; ///< Source registers.
73
  ShuffleVectorPseudo(unsigned Opc, Register Dst,
74
                      std::initializer_list<SrcOp> SrcOps)
75
226
      : Opc(Opc), Dst(Dst), SrcOps(SrcOps){};
76
405k
  ShuffleVectorPseudo() = default;
77
};
78
79
/// Check if a vector shuffle corresponds to a REV instruction with the
80
/// specified blocksize.
81
bool isREVMask(ArrayRef<int> M, unsigned EltSize, unsigned NumElts,
82
664
               unsigned BlockSize) {
83
664
  assert((BlockSize == 16 || BlockSize == 32 || BlockSize == 64) &&
84
664
         "Only possible block sizes for REV are: 16, 32, 64");
85
0
  assert(EltSize != 64 && "EltSize cannot be 64 for REV mask.");
86
87
0
  unsigned BlockElts = M[0] + 1;
88
89
  // If the first shuffle index is UNDEF, be optimistic.
90
664
  if (M[0] < 0)
91
137
    BlockElts = BlockSize / EltSize;
92
93
664
  if (BlockSize <= EltSize || BlockSize != BlockElts * EltSize)
94
423
    return false;
95
96
827
  for (unsigned i = 0; i < NumElts; ++i) {
97
    // Ignore undef indices.
98
714
    if (M[i] < 0)
99
404
      continue;
100
310
    if (static_cast<unsigned>(M[i]) !=
101
310
        (i - i % BlockElts) + (BlockElts - 1 - i % BlockElts))
102
128
      return false;
103
310
  }
104
105
113
  return true;
106
241
}
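
For intuition, here is a minimal standalone sketch of the REV mask arithmetic above (plain C++, independent of LLVM; the helper name is hypothetical and UNDEF lanes are not handled):

#include <cassert>
#include <cstdio>
#include <vector>

// Simplified re-statement of the mask test above, using plain ints
// instead of LLVM types.
static bool looksLikeREVMask(const std::vector<int> &M, unsigned EltSize,
                             unsigned BlockSize) {
  unsigned BlockElts = BlockSize / EltSize;
  if (BlockSize <= EltSize || BlockElts * EltSize != BlockSize)
    return false;
  for (unsigned i = 0; i < M.size(); ++i)
    if (static_cast<unsigned>(M[i]) !=
        (i - i % BlockElts) + (BlockElts - 1 - i % BlockElts))
      return false;
  return true;
}

int main() {
  // <8 x s8> with mask <7,6,5,4,3,2,1,0> reverses bytes inside a single
  // 64-bit block, i.e. it is a REV64 candidate.
  assert(looksLikeREVMask({7, 6, 5, 4, 3, 2, 1, 0}, 8, 64));
  // <3,2,1,0,7,6,5,4> only reverses within 32-bit blocks: a REV32 shape,
  // not a REV64 one.
  assert(!looksLikeREVMask({3, 2, 1, 0, 7, 6, 5, 4}, 8, 64));
  assert(looksLikeREVMask({3, 2, 1, 0, 7, 6, 5, 4}, 8, 32));
  std::puts("REV mask sketch OK");
}
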
107
108
/// Determines if \p M is a shuffle vector mask for a TRN of \p NumElts.
109
/// Whether or not G_TRN1 or G_TRN2 should be used is stored in \p WhichResult.
110
532
bool isTRNMask(ArrayRef<int> M, unsigned NumElts, unsigned &WhichResult) {
111
532
  if (NumElts % 2 != 0)
112
0
    return false;
113
532
  WhichResult = (M[0] == 0 ? 0 : 1);
114
591
  for (unsigned i = 0; i < NumElts; i += 2) {
115
590
    if ((M[i] >= 0 && static_cast<unsigned>(M[i]) != i + WhichResult) ||
116
590
        (M[i + 1] >= 0 &&
117
418
         static_cast<unsigned>(M[i + 1]) != i + NumElts + WhichResult))
118
531
      return false;
119
590
  }
120
1
  return true;
121
532
}
122
123
/// Check if a G_EXT instruction can handle a shuffle mask \p M when the vector
124
/// sources of the shuffle are different.
125
std::optional<std::pair<bool, uint64_t>> getExtMask(ArrayRef<int> M,
126
620
                                                    unsigned NumElts) {
127
  // Look for the first non-undef element.
128
693
  auto FirstRealElt = find_if(M, [](int Elt) { return Elt >= 0; });
129
620
  if (FirstRealElt == M.end())
130
6
    return std::nullopt;
131
132
  // Use APInt to handle overflow when calculating expected element.
133
614
  unsigned MaskBits = APInt(32, NumElts * 2).logBase2();
134
614
  APInt ExpectedElt = APInt(MaskBits, *FirstRealElt + 1);
135
136
  // The following shuffle indices must be the successive elements after the
137
  // first real element.
138
614
  if (any_of(
139
614
          make_range(std::next(FirstRealElt), M.end()),
140
966
          [&ExpectedElt](int Elt) { return Elt != ExpectedElt++ && Elt >= 0; }))
141
536
    return std::nullopt;
142
143
  // The index of an EXT is the first element if it is not UNDEF.
144
  // Watch out for the beginning UNDEFs. The EXT index should be the expected
145
  // value of the first element.  E.g.
146
  // <-1, -1, 3, ...> is treated as <1, 2, 3, ...>.
147
  // <-1, -1, 0, 1, ...> is treated as <2*NumElts-2, 2*NumElts-1, 0, 1, ...>.
148
  // ExpectedElt is the last mask index plus 1.
149
78
  uint64_t Imm = ExpectedElt.getZExtValue();
150
78
  bool ReverseExt = false;
151
152
  // There are two different cases that require reversing the input vectors.
153
  // For example, for vector <4 x i32> we have the following cases,
154
  // Case 1: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, -1, 0>)
155
  // Case 2: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, 7, 0>)
156
  // For both cases, we finally use mask <5, 6, 7, 0>, which requires
157
  // reversing the two input vectors.
158
78
  if (Imm < NumElts)
159
1
    ReverseExt = true;
160
77
  else
161
77
    Imm -= NumElts;
162
78
  return std::make_pair(ReverseExt, Imm);
163
614
}
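
The walk above can be pictured on a concrete mask. A minimal standalone sketch (plain C++, no APInt and no handling of a leading UNDEF run; names are hypothetical) of the same index walk, checked against the <4 x s32> mask <1, 2, 3, 4>:

#include <cassert>
#include <vector>

// Walk the mask expecting consecutive indices modulo 2 * NumElts, then
// derive the EXT element index and whether the inputs must be swapped.
static bool sketchGetExt(const std::vector<int> &M, unsigned NumElts,
                         bool &ReverseExt, unsigned &Imm) {
  unsigned Expected = M[0] + 1;
  for (size_t i = 1; i < M.size(); ++i, ++Expected)
    if (M[i] >= 0 && static_cast<unsigned>(M[i]) != Expected % (2 * NumElts))
      return false;
  Imm = Expected % (2 * NumElts);
  ReverseExt = Imm < NumElts; // mask effectively starts in the second input
  if (!ReverseExt)
    Imm -= NumElts;
  return true;
}

int main() {
  // Mask <1, 2, 3, 4> picks consecutive elements starting at index 1, so it
  // is an EXT of (V1, V2) with element index 1 (byte immediate 1 * 4 here).
  bool Rev;
  unsigned Imm;
  assert(sketchGetExt({1, 2, 3, 4}, 4, Rev, Imm) && !Rev && Imm == 1);
}
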
164
165
/// Determines if \p M is a shuffle vector mask for a UZP of \p NumElts.
166
/// Whether or not G_UZP1 or G_UZP2 should be used is stored in \p WhichResult.
167
532
bool isUZPMask(ArrayRef<int> M, unsigned NumElts, unsigned &WhichResult) {
168
532
  WhichResult = (M[0] == 0 ? 0 : 1);
169
964
  for (unsigned i = 0; i != NumElts; ++i) {
170
    // Skip undef indices.
171
964
    if (M[i] < 0)
172
58
      continue;
173
906
    if (static_cast<unsigned>(M[i]) != 2 * i + WhichResult)
174
532
      return false;
175
906
  }
176
0
  return true;
177
532
}
178
179
/// \return true if \p M is a zip mask for a shuffle vector of \p NumElts.
180
/// Whether or not G_ZIP1 or G_ZIP2 should be used is stored in \p WhichResult.
181
538
bool isZipMask(ArrayRef<int> M, unsigned NumElts, unsigned &WhichResult) {
182
538
  if (NumElts % 2 != 0)
183
0
    return false;
184
185
  // 0 means use ZIP1, 1 means use ZIP2.
186
538
  WhichResult = (M[0] == 0 ? 0 : 1);
187
538
  unsigned Idx = WhichResult * NumElts / 2;
188
579
  for (unsigned i = 0; i != NumElts; i += 2) {
189
573
    if ((M[i] >= 0 && static_cast<unsigned>(M[i]) != Idx) ||
190
573
        (M[i + 1] >= 0 && static_cast<unsigned>(M[i + 1]) != Idx + NumElts))
191
532
      return false;
192
41
    Idx += 1;
193
41
  }
194
6
  return true;
195
538
}
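
As a concrete reference for the three interleave checks above, the following standalone sketch (plain C++; the helper name is hypothetical) prints the canonical UNDEF-free masks they accept for a given element count:

#include <cstdio>

// WhichResult = 0 selects the "1" form (TRN1/UZP1/ZIP1), WhichResult = 1
// the "2" form, matching the index formulas in the checks above.
static void printInterleaveMasks(int NumElts) {
  std::printf("NumElts = %d\n", NumElts);
  for (int Which = 0; Which < 2; ++Which) {
    std::printf("  TRN%d:", Which + 1);
    for (int i = 0; i < NumElts; i += 2)
      std::printf(" %d %d", i + Which, i + NumElts + Which);
    std::printf("\n  UZP%d:", Which + 1);
    for (int i = 0; i < NumElts; ++i)
      std::printf(" %d", 2 * i + Which);
    std::printf("\n  ZIP%d:", Which + 1);
    for (int i = 0, Idx = Which * NumElts / 2; i < NumElts; i += 2, ++Idx)
      std::printf(" %d %d", Idx, Idx + NumElts);
    std::printf("\n");
  }
}

int main() {
  // For NumElts = 4 this prints TRN1 = <0 4 2 6>, UZP1 = <0 2 4 6> and
  // ZIP1 = <0 4 1 5>; indices 4-7 refer to the second source vector.
  printInterleaveMasks(4);
}
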
196
197
/// Helper function for matchINS.
198
///
199
/// \returns a value when \p M is an ins mask for \p NumInputElements.
200
///
201
/// First element of the returned pair is true when the produced
202
/// G_INSERT_VECTOR_ELT destination should be the LHS of the G_SHUFFLE_VECTOR.
203
///
204
/// Second element is the destination lane for the G_INSERT_VECTOR_ELT.
205
std::optional<std::pair<bool, int>> isINSMask(ArrayRef<int> M,
206
463
                                              int NumInputElements) {
207
463
  if (M.size() != static_cast<size_t>(NumInputElements))
208
0
    return std::nullopt;
209
463
  int NumLHSMatch = 0, NumRHSMatch = 0;
210
463
  int LastLHSMismatch = -1, LastRHSMismatch = -1;
211
2.39k
  for (int Idx = 0; Idx < NumInputElements; ++Idx) {
212
1.93k
    if (M[Idx] == -1) {
213
147
      ++NumLHSMatch;
214
147
      ++NumRHSMatch;
215
147
      continue;
216
147
    }
217
1.78k
    M[Idx] == Idx ? ++NumLHSMatch : LastLHSMismatch = Idx;
218
1.78k
    M[Idx] == Idx + NumInputElements ? ++NumRHSMatch : LastRHSMismatch = Idx;
219
1.78k
  }
220
463
  const int NumNeededToMatch = NumInputElements - 1;
221
463
  if (NumLHSMatch == NumNeededToMatch)
222
114
    return std::make_pair(true, LastLHSMismatch);
223
349
  if (NumRHSMatch == NumNeededToMatch)
224
49
    return std::make_pair(false, LastRHSMismatch);
225
300
  return std::nullopt;
226
349
}
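
A worked instance of the match above, written as a standalone sketch (plain C++; the mirror helper is hypothetical):

#include <cassert>
#include <vector>

// Mirror of the counting walk above: the mask must match one input on all
// lanes except one, and that lane becomes the insert destination.
static bool sketchIsINS(const std::vector<int> &M, bool &DstIsLeft, int &Lane) {
  int N = static_cast<int>(M.size());
  int LHSMatch = 0, RHSMatch = 0, LastLHSMiss = -1, LastRHSMiss = -1;
  for (int i = 0; i < N; ++i) {
    if (M[i] < 0) { ++LHSMatch; ++RHSMatch; continue; }
    M[i] == i ? ++LHSMatch : LastLHSMiss = i;
    M[i] == i + N ? ++RHSMatch : LastRHSMiss = i;
  }
  if (LHSMatch == N - 1) { DstIsLeft = true; Lane = LastLHSMiss; return true; }
  if (RHSMatch == N - 1) { DstIsLeft = false; Lane = LastRHSMiss; return true; }
  return false;
}

int main() {
  // Mask <0, 1, 6, 3> keeps three lanes of the left input and pulls lane 2
  // of the right input (index 6 - 4) into lane 2, so the shuffle lowers to
  // one extract/insert pair.
  bool DstIsLeft;
  int Lane;
  assert(sketchIsINS({0, 1, 6, 3}, DstIsLeft, Lane) && DstIsLeft && Lane == 2);
}
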
227
228
/// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with a
229
/// G_REV instruction. Returns the appropriate G_REV opcode in \p Opc.
230
bool matchREV(MachineInstr &MI, MachineRegisterInfo &MRI,
231
733
              ShuffleVectorPseudo &MatchInfo) {
232
733
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
233
0
  ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
234
733
  Register Dst = MI.getOperand(0).getReg();
235
733
  Register Src = MI.getOperand(1).getReg();
236
733
  LLT Ty = MRI.getType(Dst);
237
733
  unsigned EltSize = Ty.getScalarSizeInBits();
238
239
  // Element size for a rev cannot be 64.
240
733
  if (EltSize == 64)
241
69
    return false;
242
243
664
  unsigned NumElts = Ty.getNumElements();
244
245
  // Try to produce G_REV64
246
664
  if (isREVMask(ShuffleMask, EltSize, NumElts, 64)) {
247
113
    MatchInfo = ShuffleVectorPseudo(AArch64::G_REV64, Dst, {Src});
248
113
    return true;
249
113
  }
250
251
  // TODO: Produce G_REV32 and G_REV16 once we have proper legalization support.
252
  // This should be identical to above, but with a constant 32 and constant
253
  // 16.
254
551
  return false;
255
664
}
256
257
/// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with
258
/// a G_TRN1 or G_TRN2 instruction.
259
bool matchTRN(MachineInstr &MI, MachineRegisterInfo &MRI,
260
532
              ShuffleVectorPseudo &MatchInfo) {
261
532
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
262
0
  unsigned WhichResult;
263
532
  ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
264
532
  Register Dst = MI.getOperand(0).getReg();
265
532
  unsigned NumElts = MRI.getType(Dst).getNumElements();
266
532
  if (!isTRNMask(ShuffleMask, NumElts, WhichResult))
267
531
    return false;
268
1
  unsigned Opc = (WhichResult == 0) ? AArch64::G_TRN1 : AArch64::G_TRN2;
269
1
  Register V1 = MI.getOperand(1).getReg();
270
1
  Register V2 = MI.getOperand(2).getReg();
271
1
  MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
272
1
  return true;
273
532
}
274
275
/// \return true if a G_SHUFFLE_VECTOR instruction \p MI can be replaced with
276
/// a G_UZP1 or G_UZP2 instruction.
277
///
278
/// \param [in] MI - The shuffle vector instruction.
279
/// \param [out] MatchInfo - Either G_UZP1 or G_UZP2 on success.
280
bool matchUZP(MachineInstr &MI, MachineRegisterInfo &MRI,
281
532
              ShuffleVectorPseudo &MatchInfo) {
282
532
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
283
0
  unsigned WhichResult;
284
532
  ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
285
532
  Register Dst = MI.getOperand(0).getReg();
286
532
  unsigned NumElts = MRI.getType(Dst).getNumElements();
287
532
  if (!isUZPMask(ShuffleMask, NumElts, WhichResult))
288
532
    return false;
289
0
  unsigned Opc = (WhichResult == 0) ? AArch64::G_UZP1 : AArch64::G_UZP2;
290
0
  Register V1 = MI.getOperand(1).getReg();
291
0
  Register V2 = MI.getOperand(2).getReg();
292
0
  MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
293
0
  return true;
294
532
}
295
296
bool matchZip(MachineInstr &MI, MachineRegisterInfo &MRI,
297
538
              ShuffleVectorPseudo &MatchInfo) {
298
538
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
299
0
  unsigned WhichResult;
300
538
  ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
301
538
  Register Dst = MI.getOperand(0).getReg();
302
538
  unsigned NumElts = MRI.getType(Dst).getNumElements();
303
538
  if (!isZipMask(ShuffleMask, NumElts, WhichResult))
304
532
    return false;
305
6
  unsigned Opc = (WhichResult == 0) ? AArch64::G_ZIP1 : AArch64::G_ZIP2;
306
6
  Register V1 = MI.getOperand(1).getReg();
307
6
  Register V2 = MI.getOperand(2).getReg();
308
6
  MatchInfo = ShuffleVectorPseudo(Opc, Dst, {V1, V2});
309
6
  return true;
310
538
}
311
312
/// Helper function for matchDup.
313
bool matchDupFromInsertVectorElt(int Lane, MachineInstr &MI,
314
                                 MachineRegisterInfo &MRI,
315
209
                                 ShuffleVectorPseudo &MatchInfo) {
316
209
  if (Lane != 0)
317
34
    return false;
318
319
  // Try to match a vector splat operation into a dup instruction.
320
  // We're looking for this pattern:
321
  //
322
  // %scalar:gpr(s64) = COPY $x0
323
  // %undef:fpr(<2 x s64>) = G_IMPLICIT_DEF
324
  // %cst0:gpr(s32) = G_CONSTANT i32 0
325
  // %zerovec:fpr(<2 x s32>) = G_BUILD_VECTOR %cst0(s32), %cst0(s32)
326
  // %ins:fpr(<2 x s64>) = G_INSERT_VECTOR_ELT %undef, %scalar(s64), %cst0(s32)
327
  // %splat:fpr(<2 x s64>) = G_SHUFFLE_VECTOR %ins(<2 x s64>), %undef,
328
  // %zerovec(<2 x s32>)
329
  //
330
  // ...into:
331
  // %splat = G_DUP %scalar
332
333
  // Begin matching the insert.
334
175
  auto *InsMI = getOpcodeDef(TargetOpcode::G_INSERT_VECTOR_ELT,
335
175
                             MI.getOperand(1).getReg(), MRI);
336
175
  if (!InsMI)
337
134
    return false;
338
  // Match the undef vector operand.
339
41
  if (!getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, InsMI->getOperand(1).getReg(),
340
41
                    MRI))
341
19
    return false;
342
343
  // Match the index constant 0.
344
22
  if (!mi_match(InsMI->getOperand(3).getReg(), MRI, m_ZeroInt()))
345
7
    return false;
346
347
15
  MatchInfo = ShuffleVectorPseudo(AArch64::G_DUP, MI.getOperand(0).getReg(),
348
15
                                  {InsMI->getOperand(2).getReg()});
349
15
  return true;
350
22
}
351
352
/// Helper function for matchDup.
353
bool matchDupFromBuildVector(int Lane, MachineInstr &MI,
354
                             MachineRegisterInfo &MRI,
355
194
                             ShuffleVectorPseudo &MatchInfo) {
356
194
  assert(Lane >= 0 && "Expected positive lane?");
357
  // Test if the LHS is a BUILD_VECTOR. If it is, then we can just reference the
358
  // lane's definition directly.
359
0
  auto *BuildVecMI = getOpcodeDef(TargetOpcode::G_BUILD_VECTOR,
360
194
                                  MI.getOperand(1).getReg(), MRI);
361
194
  if (!BuildVecMI)
362
185
    return false;
363
9
  Register Reg = BuildVecMI->getOperand(Lane + 1).getReg();
364
9
  MatchInfo =
365
9
      ShuffleVectorPseudo(AArch64::G_DUP, MI.getOperand(0).getReg(), {Reg});
366
9
  return true;
367
194
}
368
369
bool matchDup(MachineInstr &MI, MachineRegisterInfo &MRI,
370
757
              ShuffleVectorPseudo &MatchInfo) {
371
757
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
372
0
  auto MaybeLane = getSplatIndex(MI);
373
757
  if (!MaybeLane)
374
548
    return false;
375
209
  int Lane = *MaybeLane;
376
  // If this is undef splat, generate it via "just" vdup, if possible.
377
209
  if (Lane < 0)
378
0
    Lane = 0;
379
209
  if (matchDupFromInsertVectorElt(Lane, MI, MRI, MatchInfo))
380
15
    return true;
381
194
  if (matchDupFromBuildVector(Lane, MI, MRI, MatchInfo))
382
9
    return true;
383
185
  return false;
384
194
}
385
386
// Check if an EXT instruction can handle the shuffle mask when the vector
387
// sources of the shuffle are the same.
388
147
bool isSingletonExtMask(ArrayRef<int> M, LLT Ty) {
389
147
  unsigned NumElts = Ty.getNumElements();
390
391
  // Assume that the first shuffle index is not UNDEF.  Fail if it is.
392
147
  if (M[0] < 0)
393
9
    return false;
394
395
  // If this is a VEXT shuffle, the immediate value is the index of the first
396
  // element.  The other shuffle indices must be the successive elements after
397
  // the first one.
398
138
  unsigned ExpectedElt = M[0];
399
172
  for (unsigned I = 1; I < NumElts; ++I) {
400
    // Increment the expected index.  If it wraps around, just follow it
401
    // back to index zero and keep going.
402
168
    ++ExpectedElt;
403
168
    if (ExpectedElt == NumElts)
404
7
      ExpectedElt = 0;
405
406
168
    if (M[I] < 0)
407
0
      continue; // Ignore UNDEF indices.
408
168
    if (ExpectedElt != static_cast<unsigned>(M[I]))
409
134
      return false;
410
168
  }
411
412
4
  return true;
413
138
}
414
415
bool matchEXT(MachineInstr &MI, MachineRegisterInfo &MRI,
416
620
              ShuffleVectorPseudo &MatchInfo) {
417
620
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
418
0
  Register Dst = MI.getOperand(0).getReg();
419
620
  LLT DstTy = MRI.getType(Dst);
420
620
  Register V1 = MI.getOperand(1).getReg();
421
620
  Register V2 = MI.getOperand(2).getReg();
422
620
  auto Mask = MI.getOperand(3).getShuffleMask();
423
620
  uint64_t Imm;
424
620
  auto ExtInfo = getExtMask(Mask, DstTy.getNumElements());
425
620
  uint64_t ExtFactor = MRI.getType(V1).getScalarSizeInBits() / 8;
426
427
620
  if (!ExtInfo) {
428
542
    if (!getOpcodeDef<GImplicitDef>(V2, MRI) ||
429
542
        !isSingletonExtMask(Mask, DstTy))
430
538
      return false;
431
432
4
    Imm = Mask[0] * ExtFactor;
433
4
    MatchInfo = ShuffleVectorPseudo(AArch64::G_EXT, Dst, {V1, V1, Imm});
434
4
    return true;
435
542
  }
436
78
  bool ReverseExt;
437
78
  std::tie(ReverseExt, Imm) = *ExtInfo;
438
78
  if (ReverseExt)
439
1
    std::swap(V1, V2);
440
78
  Imm *= ExtFactor;
441
78
  MatchInfo = ShuffleVectorPseudo(AArch64::G_EXT, Dst, {V1, V2, Imm});
442
78
  return true;
443
620
}
444
445
/// Replace a G_SHUFFLE_VECTOR instruction with a pseudo.
446
/// \p Opc is the opcode to use. \p MI is the G_SHUFFLE_VECTOR.
447
void applyShuffleVectorPseudo(MachineInstr &MI,
448
144
                              ShuffleVectorPseudo &MatchInfo) {
449
144
  MachineIRBuilder MIRBuilder(MI);
450
144
  MIRBuilder.buildInstr(MatchInfo.Opc, {MatchInfo.Dst}, MatchInfo.SrcOps);
451
144
  MI.eraseFromParent();
452
144
}
453
454
/// Replace a G_SHUFFLE_VECTOR instruction with G_EXT.
455
/// Special-cased because the constant operand must be emitted as a G_CONSTANT
456
/// for the imported tablegen patterns to work.
457
82
void applyEXT(MachineInstr &MI, ShuffleVectorPseudo &MatchInfo) {
458
82
  MachineIRBuilder MIRBuilder(MI);
459
82
  if (MatchInfo.SrcOps[2].getImm() == 0)
460
76
    MIRBuilder.buildCopy(MatchInfo.Dst, MatchInfo.SrcOps[0]);
461
6
  else {
462
    // Tablegen patterns expect an i32 G_CONSTANT as the final op.
463
6
    auto Cst =
464
6
        MIRBuilder.buildConstant(LLT::scalar(32), MatchInfo.SrcOps[2].getImm());
465
6
    MIRBuilder.buildInstr(MatchInfo.Opc, {MatchInfo.Dst},
466
6
                          {MatchInfo.SrcOps[0], MatchInfo.SrcOps[1], Cst});
467
6
  }
468
82
  MI.eraseFromParent();
469
82
}
470
471
/// Match a G_SHUFFLE_VECTOR with a mask which corresponds to a
472
/// G_INSERT_VECTOR_ELT and G_EXTRACT_VECTOR_ELT pair.
473
///
474
/// e.g.
475
///   %shuf = G_SHUFFLE_VECTOR %left, %right, shufflemask(0, 0)
476
///
477
/// Can be represented as
478
///
479
///   %extract = G_EXTRACT_VECTOR_ELT %left, 0
480
///   %ins = G_INSERT_VECTOR_ELT %left, %extract, 1
481
///
482
bool matchINS(MachineInstr &MI, MachineRegisterInfo &MRI,
483
463
              std::tuple<Register, int, Register, int> &MatchInfo) {
484
463
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
485
0
  ArrayRef<int> ShuffleMask = MI.getOperand(3).getShuffleMask();
486
463
  Register Dst = MI.getOperand(0).getReg();
487
463
  int NumElts = MRI.getType(Dst).getNumElements();
488
463
  auto DstIsLeftAndDstLane = isINSMask(ShuffleMask, NumElts);
489
463
  if (!DstIsLeftAndDstLane)
490
300
    return false;
491
163
  bool DstIsLeft;
492
163
  int DstLane;
493
163
  std::tie(DstIsLeft, DstLane) = *DstIsLeftAndDstLane;
494
163
  Register Left = MI.getOperand(1).getReg();
495
163
  Register Right = MI.getOperand(2).getReg();
496
163
  Register DstVec = DstIsLeft ? Left : Right;
497
163
  Register SrcVec = Left;
498
499
163
  int SrcLane = ShuffleMask[DstLane];
500
163
  if (SrcLane >= NumElts) {
501
110
    SrcVec = Right;
502
110
    SrcLane -= NumElts;
503
110
  }
504
505
163
  MatchInfo = std::make_tuple(DstVec, DstLane, SrcVec, SrcLane);
506
163
  return true;
507
463
}
508
509
void applyINS(MachineInstr &MI, MachineRegisterInfo &MRI,
510
              MachineIRBuilder &Builder,
511
163
              std::tuple<Register, int, Register, int> &MatchInfo) {
512
163
  Builder.setInstrAndDebugLoc(MI);
513
163
  Register Dst = MI.getOperand(0).getReg();
514
163
  auto ScalarTy = MRI.getType(Dst).getElementType();
515
163
  Register DstVec, SrcVec;
516
163
  int DstLane, SrcLane;
517
163
  std::tie(DstVec, DstLane, SrcVec, SrcLane) = MatchInfo;
518
163
  auto SrcCst = Builder.buildConstant(LLT::scalar(64), SrcLane);
519
163
  auto Extract = Builder.buildExtractVectorElement(ScalarTy, SrcVec, SrcCst);
520
163
  auto DstCst = Builder.buildConstant(LLT::scalar(64), DstLane);
521
163
  Builder.buildInsertVectorElement(Dst, DstVec, Extract, DstCst);
522
163
  MI.eraseFromParent();
523
163
}
524
525
/// isVShiftRImm - Check if this is a valid vector for the immediate
526
/// operand of a vector shift right operation. The value must be in the range:
527
///   1 <= Value <= ElementBits for a right shift.
528
bool isVShiftRImm(Register Reg, MachineRegisterInfo &MRI, LLT Ty,
529
224
                  int64_t &Cnt) {
530
224
  assert(Ty.isVector() && "vector shift count is not a vector type");
531
0
  MachineInstr *MI = MRI.getVRegDef(Reg);
532
224
  auto Cst = getAArch64VectorSplatScalar(*MI, MRI);
533
224
  if (!Cst)
534
60
    return false;
535
164
  Cnt = *Cst;
536
164
  int64_t ElementBits = Ty.getScalarSizeInBits();
537
164
  return Cnt >= 1 && Cnt <= ElementBits;
538
224
}
539
540
/// Match a vector G_ASHR or G_LSHR with a valid immediate shift.
541
bool matchVAshrLshrImm(MachineInstr &MI, MachineRegisterInfo &MRI,
542
1.77k
                       int64_t &Imm) {
543
1.77k
  assert(MI.getOpcode() == TargetOpcode::G_ASHR ||
544
1.77k
         MI.getOpcode() == TargetOpcode::G_LSHR);
545
0
  LLT Ty = MRI.getType(MI.getOperand(1).getReg());
546
1.77k
  if (!Ty.isVector())
547
1.55k
    return false;
548
224
  return isVShiftRImm(MI.getOperand(2).getReg(), MRI, Ty, Imm);
549
1.77k
}
550
551
void applyVAshrLshrImm(MachineInstr &MI, MachineRegisterInfo &MRI,
552
164
                       int64_t &Imm) {
553
164
  unsigned Opc = MI.getOpcode();
554
164
  assert(Opc == TargetOpcode::G_ASHR || Opc == TargetOpcode::G_LSHR);
555
0
  unsigned NewOpc =
556
164
      Opc == TargetOpcode::G_ASHR ? AArch64::G_VASHR : AArch64::G_VLSHR;
557
164
  MachineIRBuilder MIB(MI);
558
164
  auto ImmDef = MIB.buildConstant(LLT::scalar(32), Imm);
559
164
  MIB.buildInstr(NewOpc, {MI.getOperand(0)}, {MI.getOperand(1), ImmDef});
560
164
  MI.eraseFromParent();
561
164
}
562
563
/// Determine if it is possible to modify the \p RHS and predicate \p P of a
564
/// G_ICMP instruction such that the right-hand side is an arithmetic immediate.
565
///
566
/// \returns A pair containing the updated immediate and predicate which may
567
/// be used to optimize the instruction.
568
///
569
/// \note This assumes that the comparison has been legalized.
570
std::optional<std::pair<uint64_t, CmpInst::Predicate>>
571
tryAdjustICmpImmAndPred(Register RHS, CmpInst::Predicate P,
572
10.3k
                        const MachineRegisterInfo &MRI) {
573
10.3k
  const auto &Ty = MRI.getType(RHS);
574
10.3k
  if (Ty.isVector())
575
60
    return std::nullopt;
576
10.2k
  unsigned Size = Ty.getSizeInBits();
577
10.2k
  assert((Size == 32 || Size == 64) && "Expected 32 or 64 bit compare only?");
578
579
  // If the RHS is not a constant, or the RHS is already a valid arithmetic
580
  // immediate, then there is nothing to change.
581
0
  auto ValAndVReg = getIConstantVRegValWithLookThrough(RHS, MRI);
582
10.2k
  if (!ValAndVReg)
583
5.60k
    return std::nullopt;
584
4.63k
  uint64_t C = ValAndVReg->Value.getZExtValue();
585
4.63k
  if (isLegalArithImmed(C))
586
3.17k
    return std::nullopt;
587
588
  // We have a non-arithmetic immediate. Check if adjusting the immediate and
589
  // adjusting the predicate will result in a legal arithmetic immediate.
590
1.46k
  switch (P) {
591
233
  default:
592
233
    return std::nullopt;
593
145
  case CmpInst::ICMP_SLT:
594
284
  case CmpInst::ICMP_SGE:
595
    // Check for
596
    //
597
    // x slt c => x sle c - 1
598
    // x sge c => x sgt c - 1
599
    //
600
    // When c is not the smallest possible negative number.
601
284
    if ((Size == 64 && static_cast<int64_t>(C) == INT64_MIN) ||
602
284
        (Size == 32 && static_cast<int32_t>(C) == INT32_MIN))
603
65
      return std::nullopt;
604
219
    P = (P == CmpInst::ICMP_SLT) ? CmpInst::ICMP_SLE : CmpInst::ICMP_SGT;
605
219
    C -= 1;
606
219
    break;
607
158
  case CmpInst::ICMP_ULT:
608
339
  case CmpInst::ICMP_UGE:
609
    // Check for
610
    //
611
    // x ult c => x ule c - 1
612
    // x uge c => x ugt c - 1
613
    //
614
    // When c is not zero.
615
339
    if (C == 0)
616
0
      return std::nullopt;
617
339
    P = (P == CmpInst::ICMP_ULT) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
618
339
    C -= 1;
619
339
    break;
620
177
  case CmpInst::ICMP_SLE:
621
415
  case CmpInst::ICMP_SGT:
622
    // Check for
623
    //
624
    // x sle c => x slt c + 1
625
    // x sgt c => x sge c + 1
626
    //
627
    // When c is not the largest possible signed integer.
628
415
    if ((Size == 32 && static_cast<int32_t>(C) == INT32_MAX) ||
629
415
        (Size == 64 && static_cast<int64_t>(C) == INT64_MAX))
630
56
      return std::nullopt;
631
359
    P = (P == CmpInst::ICMP_SLE) ? CmpInst::ICMP_SLT : CmpInst::ICMP_SGE;
632
359
    C += 1;
633
359
    break;
634
99
  case CmpInst::ICMP_ULE:
635
190
  case CmpInst::ICMP_UGT:
636
    // Check for
637
    //
638
    // x ule c => x ult c + 1
639
    // x ugt c => x uge c + 1
640
    //
641
    // When c is not the largest possible unsigned integer.
642
190
    if ((Size == 32 && static_cast<uint32_t>(C) == UINT32_MAX) ||
643
190
        (Size == 64 && C == UINT64_MAX))
644
52
      return std::nullopt;
645
138
    P = (P == CmpInst::ICMP_ULE) ? CmpInst::ICMP_ULT : CmpInst::ICMP_UGE;
646
138
    C += 1;
647
138
    break;
648
1.46k
  }
649
650
  // Check if the new constant is valid, and return the updated constant and
651
  // predicate if it is.
652
1.05k
  if (Size == 32)
653
726
    C = static_cast<uint32_t>(C);
654
1.05k
  if (!isLegalArithImmed(C))
655
872
    return std::nullopt;
656
183
  return {{C, P}};
657
1.05k
}
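
To make the rewrite concrete, here is a standalone sketch (plain C++) using the usual AArch64 arithmetic-immediate encoding, a 12-bit value optionally shifted left by 12; the helper name is hypothetical and is not the isLegalArithImmed used above:

#include <cassert>
#include <cstdint>

// An add/sub (and hence cmp) immediate fits if it is a 12-bit value,
// possibly shifted left by 12 bits.
static bool fitsArithImmed(uint64_t C) {
  return C <= 0xFFFULL || ((C & 0xFFFULL) == 0 && C <= 0xFFF000ULL);
}

int main() {
  // "x ult 0x1001" cannot be encoded directly (0x1001 needs 13 bits and is
  // not a multiple of 0x1000), but the rewritten "x ule 0x1000" can, since
  // 0x1000 is 1 << 12. This mirrors the ICMP_ULT -> ICMP_ULE, C -> C - 1
  // case handled above.
  assert(!fitsArithImmed(0x1001));
  assert(fitsArithImmed(0x1001 - 1));
}
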
658
659
/// Determine whether or not it is possible to update the RHS and predicate of
660
/// a G_ICMP instruction such that the RHS will be selected as an arithmetic
661
/// immediate.
662
///
663
/// \p MI - The G_ICMP instruction
664
/// \p MatchInfo - The new RHS immediate and predicate on success
665
///
666
/// See tryAdjustICmpImmAndPred for valid transformations.
667
bool matchAdjustICmpImmAndPred(
668
    MachineInstr &MI, const MachineRegisterInfo &MRI,
669
10.3k
    std::pair<uint64_t, CmpInst::Predicate> &MatchInfo) {
670
10.3k
  assert(MI.getOpcode() == TargetOpcode::G_ICMP);
671
0
  Register RHS = MI.getOperand(3).getReg();
672
10.3k
  auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
673
10.3k
  if (auto MaybeNewImmAndPred = tryAdjustICmpImmAndPred(RHS, Pred, MRI)) {
674
183
    MatchInfo = *MaybeNewImmAndPred;
675
183
    return true;
676
183
  }
677
10.1k
  return false;
678
10.3k
}
679
680
void applyAdjustICmpImmAndPred(
681
    MachineInstr &MI, std::pair<uint64_t, CmpInst::Predicate> &MatchInfo,
682
183
    MachineIRBuilder &MIB, GISelChangeObserver &Observer) {
683
183
  MIB.setInstrAndDebugLoc(MI);
684
183
  MachineOperand &RHS = MI.getOperand(3);
685
183
  MachineRegisterInfo &MRI = *MIB.getMRI();
686
183
  auto Cst = MIB.buildConstant(MRI.cloneVirtualRegister(RHS.getReg()),
687
183
                               MatchInfo.first);
688
183
  Observer.changingInstr(MI);
689
183
  RHS.setReg(Cst->getOperand(0).getReg());
690
183
  MI.getOperand(1).setPredicate(MatchInfo.second);
691
183
  Observer.changedInstr(MI);
692
183
}
693
694
bool matchDupLane(MachineInstr &MI, MachineRegisterInfo &MRI,
695
531
                  std::pair<unsigned, int> &MatchInfo) {
696
531
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
697
0
  Register Src1Reg = MI.getOperand(1).getReg();
698
531
  const LLT SrcTy = MRI.getType(Src1Reg);
699
531
  const LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
700
701
531
  auto LaneIdx = getSplatIndex(MI);
702
531
  if (!LaneIdx)
703
463
    return false;
704
705
  // The lane idx should be within the first source vector.
706
68
  if (*LaneIdx >= SrcTy.getNumElements())
707
0
    return false;
708
709
68
  if (DstTy != SrcTy)
710
0
    return false;
711
712
68
  LLT ScalarTy = SrcTy.getElementType();
713
68
  unsigned ScalarSize = ScalarTy.getSizeInBits();
714
715
68
  unsigned Opc = 0;
716
68
  switch (SrcTy.getNumElements()) {
717
40
  case 2:
718
40
    if (ScalarSize == 64)
719
0
      Opc = AArch64::G_DUPLANE64;
720
40
    else if (ScalarSize == 32)
721
40
      Opc = AArch64::G_DUPLANE32;
722
40
    break;
723
23
  case 4:
724
23
    if (ScalarSize == 32)
725
23
      Opc = AArch64::G_DUPLANE32;
726
0
    else if (ScalarSize == 16)
727
0
      Opc = AArch64::G_DUPLANE16;
728
23
    break;
729
5
  case 8:
730
5
    if (ScalarSize == 8)
731
0
      Opc = AArch64::G_DUPLANE8;
732
5
    else if (ScalarSize == 16)
733
5
      Opc = AArch64::G_DUPLANE16;
734
5
    break;
735
0
  case 16:
736
0
    if (ScalarSize == 8)
737
0
      Opc = AArch64::G_DUPLANE8;
738
0
    break;
739
0
  default:
740
0
    break;
741
68
  }
742
68
  if (!Opc)
743
0
    return false;
744
745
68
  MatchInfo.first = Opc;
746
68
  MatchInfo.second = *LaneIdx;
747
68
  return true;
748
68
}
749
750
void applyDupLane(MachineInstr &MI, MachineRegisterInfo &MRI,
751
68
                  MachineIRBuilder &B, std::pair<unsigned, int> &MatchInfo) {
752
68
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR);
753
0
  Register Src1Reg = MI.getOperand(1).getReg();
754
68
  const LLT SrcTy = MRI.getType(Src1Reg);
755
756
68
  B.setInstrAndDebugLoc(MI);
757
68
  auto Lane = B.buildConstant(LLT::scalar(64), MatchInfo.second);
758
759
68
  Register DupSrc = MI.getOperand(1).getReg();
760
  // For types like <2 x s32>, we can use G_DUPLANE32, with a <4 x s32> source.
761
  // To do this, we can use a G_CONCAT_VECTORS to do the widening.
762
68
  if (SrcTy.getSizeInBits() == 64) {
763
40
    auto Undef = B.buildUndef(SrcTy);
764
40
    DupSrc = B.buildConcatVectors(SrcTy.multiplyElements(2),
765
40
                                  {Src1Reg, Undef.getReg(0)})
766
40
                 .getReg(0);
767
40
  }
768
68
  B.buildInstr(MatchInfo.first, {MI.getOperand(0).getReg()}, {DupSrc, Lane});
769
68
  MI.eraseFromParent();
770
68
}
771
772
554
bool matchScalarizeVectorUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI) {
773
554
  auto &Unmerge = cast<GUnmerge>(MI);
774
554
  Register Src1Reg = Unmerge.getReg(Unmerge.getNumOperands() - 1);
775
554
  const LLT SrcTy = MRI.getType(Src1Reg);
776
554
  return SrcTy.isVector() && !SrcTy.isScalable() &&
777
554
         Unmerge.getNumOperands() == (unsigned)SrcTy.getNumElements() + 1;
778
554
}
779
780
void applyScalarizeVectorUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
781
436
                                 MachineIRBuilder &B) {
782
436
  auto &Unmerge = cast<GUnmerge>(MI);
783
436
  Register Src1Reg = Unmerge.getReg(Unmerge.getNumOperands() - 1);
784
436
  const LLT SrcTy = MRI.getType(Src1Reg);
785
436
  assert((SrcTy.isVector() && !SrcTy.isScalable()) &&
786
436
         "Expected a fixed length vector");
787
788
1.63k
  for (int I = 0; I < SrcTy.getNumElements(); ++I)
789
1.19k
    B.buildExtractVectorElementConstant(Unmerge.getReg(I), Src1Reg, I);
790
436
  MI.eraseFromParent();
791
436
}
792
793
7.01k
bool matchBuildVectorToDup(MachineInstr &MI, MachineRegisterInfo &MRI) {
794
7.01k
  assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR);
795
0
  auto Splat = getAArch64VectorSplat(MI, MRI);
796
7.01k
  if (!Splat)
797
1.94k
    return false;
798
5.07k
  if (Splat->isReg())
799
3.92k
    return true;
800
  // Later, during selection, we'll try to match imported patterns using
801
  // immAllOnesV and immAllZerosV. These require G_BUILD_VECTOR. Don't lower
802
  // G_BUILD_VECTORs which could match those patterns.
803
1.15k
  int64_t Cst = Splat->getCst();
804
1.15k
  return (Cst != 0 && Cst != -1);
805
5.07k
}
806
807
void applyBuildVectorToDup(MachineInstr &MI, MachineRegisterInfo &MRI,
808
4.32k
                           MachineIRBuilder &B) {
809
4.32k
  B.setInstrAndDebugLoc(MI);
810
4.32k
  B.buildInstr(AArch64::G_DUP, {MI.getOperand(0).getReg()},
811
4.32k
               {MI.getOperand(1).getReg()});
812
4.32k
  MI.eraseFromParent();
813
4.32k
}
814
815
/// \returns how many instructions would be saved by folding a G_ICMP's shift
816
/// and/or extension operations.
817
13.8k
unsigned getCmpOperandFoldingProfit(Register CmpOp, MachineRegisterInfo &MRI) {
818
  // No instructions to save if there's more than one use or no uses.
819
13.8k
  if (!MRI.hasOneNonDBGUse(CmpOp))
820
4.66k
    return 0;
821
822
  // FIXME: This is duplicated with the selector. (See: selectShiftedRegister)
823
9.24k
  auto IsSupportedExtend = [&](const MachineInstr &MI) {
824
9.24k
    if (MI.getOpcode() == TargetOpcode::G_SEXT_INREG)
825
2.97k
      return true;
826
6.27k
    if (MI.getOpcode() != TargetOpcode::G_AND)
827
3.49k
      return false;
828
2.77k
    auto ValAndVReg =
829
2.77k
        getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
830
2.77k
    if (!ValAndVReg)
831
15
      return false;
832
2.76k
    uint64_t Mask = ValAndVReg->Value.getZExtValue();
833
2.76k
    return (Mask == 0xFF || Mask == 0xFFFF || Mask == 0xFFFFFFFF);
834
2.77k
  };
835
836
9.21k
  MachineInstr *Def = getDefIgnoringCopies(CmpOp, MRI);
837
9.21k
  if (IsSupportedExtend(*Def))
838
3.38k
    return 1;
839
840
5.83k
  unsigned Opc = Def->getOpcode();
841
5.83k
  if (Opc != TargetOpcode::G_SHL && Opc != TargetOpcode::G_ASHR &&
842
5.83k
      Opc != TargetOpcode::G_LSHR)
843
5.79k
    return 0;
844
845
42
  auto MaybeShiftAmt =
846
42
      getIConstantVRegValWithLookThrough(Def->getOperand(2).getReg(), MRI);
847
42
  if (!MaybeShiftAmt)
848
13
    return 0;
849
29
  uint64_t ShiftAmt = MaybeShiftAmt->Value.getZExtValue();
850
29
  MachineInstr *ShiftLHS =
851
29
      getDefIgnoringCopies(Def->getOperand(1).getReg(), MRI);
852
853
  // Check if we can fold an extend and a shift.
854
  // FIXME: This is duplicated with the selector. (See:
855
  // selectArithExtendedRegister)
856
29
  if (IsSupportedExtend(*ShiftLHS))
857
0
    return (ShiftAmt <= 4) ? 2 : 1;
858
859
29
  LLT Ty = MRI.getType(Def->getOperand(0).getReg());
860
29
  if (Ty.isVector())
861
0
    return 0;
862
29
  unsigned ShiftSize = Ty.getSizeInBits();
863
29
  if ((ShiftSize == 32 && ShiftAmt <= 31) ||
864
29
      (ShiftSize == 64 && ShiftAmt <= 63))
865
29
    return 1;
866
0
  return 0;
867
29
}
868
869
/// \returns true if it would be profitable to swap the LHS and RHS of a G_ICMP
870
/// instruction \p MI.
871
10.1k
bool trySwapICmpOperands(MachineInstr &MI, MachineRegisterInfo &MRI) {
872
10.1k
  assert(MI.getOpcode() == TargetOpcode::G_ICMP);
873
  // Swap the operands if it would introduce a profitable folding opportunity.
874
  // (e.g. a shift + extend).
875
  //
876
  //  For example:
877
  //    lsl     w13, w11, #1
878
  //    cmp     w13, w12
879
  // can be turned into:
880
  //    cmp     w12, w11, lsl #1
881
882
  // Don't swap if there's a constant on the RHS, because we know we can fold
883
  // that.
884
0
  Register RHS = MI.getOperand(3).getReg();
885
10.1k
  auto RHSCst = getIConstantVRegValWithLookThrough(RHS, MRI);
886
10.1k
  if (RHSCst && isLegalArithImmed(RHSCst->Value.getSExtValue()))
887
3.17k
    return false;
888
889
6.94k
  Register LHS = MI.getOperand(2).getReg();
890
6.94k
  auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
891
13.8k
  auto GetRegForProfit = [&](Register Reg) {
892
13.8k
    MachineInstr *Def = getDefIgnoringCopies(Reg, MRI);
893
13.8k
    return isCMN(Def, Pred, MRI) ? Def->getOperand(2).getReg() : Reg;
894
13.8k
  };
895
896
  // Don't have a constant on the RHS. If we swap the LHS and RHS of the
897
  // compare, would we be able to fold more instructions?
898
6.94k
  Register TheLHS = GetRegForProfit(LHS);
899
6.94k
  Register TheRHS = GetRegForProfit(RHS);
900
901
  // If the LHS is more likely to give us a folding opportunity, then swap the
902
  // LHS and RHS.
903
6.94k
  return (getCmpOperandFoldingProfit(TheLHS, MRI) >
904
6.94k
          getCmpOperandFoldingProfit(TheRHS, MRI));
905
10.1k
}
906
907
45
void applySwapICmpOperands(MachineInstr &MI, GISelChangeObserver &Observer) {
908
45
  auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
909
45
  Register LHS = MI.getOperand(2).getReg();
910
45
  Register RHS = MI.getOperand(3).getReg();
911
45
  Observer.changedInstr(MI);
912
45
  MI.getOperand(1).setPredicate(CmpInst::getSwappedPredicate(Pred));
913
45
  MI.getOperand(2).setReg(RHS);
914
45
  MI.getOperand(3).setReg(LHS);
915
45
  Observer.changedInstr(MI);
916
45
}
917
918
/// \returns a function which builds a vector floating point compare instruction
919
/// for a condition code \p CC.
920
/// \param [in] IsZero - True if the comparison is against 0.
921
/// \param [in] NoNans - True if the target has NoNansFPMath.
922
std::function<Register(MachineIRBuilder &)>
923
getVectorFCMP(AArch64CC::CondCode CC, Register LHS, Register RHS, bool IsZero,
924
37
              bool NoNans, MachineRegisterInfo &MRI) {
925
37
  LLT DstTy = MRI.getType(LHS);
926
37
  assert(DstTy.isVector() && "Expected vector types only?");
927
0
  assert(DstTy == MRI.getType(RHS) && "Src and Dst types must match!");
928
0
  switch (CC) {
929
0
  default:
930
0
    llvm_unreachable("Unexpected condition code!");
931
0
  case AArch64CC::NE:
932
0
    return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
933
0
      auto FCmp = IsZero
934
0
                      ? MIB.buildInstr(AArch64::G_FCMEQZ, {DstTy}, {LHS})
935
0
                      : MIB.buildInstr(AArch64::G_FCMEQ, {DstTy}, {LHS, RHS});
936
0
      return MIB.buildNot(DstTy, FCmp).getReg(0);
937
0
    };
938
0
  case AArch64CC::EQ:
939
0
    return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
940
0
      return IsZero
941
0
                 ? MIB.buildInstr(AArch64::G_FCMEQZ, {DstTy}, {LHS}).getReg(0)
942
0
                 : MIB.buildInstr(AArch64::G_FCMEQ, {DstTy}, {LHS, RHS})
943
0
                       .getReg(0);
944
0
    };
945
6
  case AArch64CC::GE:
946
6
    return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
947
6
      return IsZero
948
6
                 ? MIB.buildInstr(AArch64::G_FCMGEZ, {DstTy}, {LHS}).getReg(0)
949
6
                 : MIB.buildInstr(AArch64::G_FCMGE, {DstTy}, {LHS, RHS})
950
3
                       .getReg(0);
951
6
    };
952
8
  case AArch64CC::GT:
953
8
    return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
954
8
      return IsZero
955
8
                 ? MIB.buildInstr(AArch64::G_FCMGTZ, {DstTy}, {LHS}).getReg(0)
956
8
                 : MIB.buildInstr(AArch64::G_FCMGT, {DstTy}, {LHS, RHS})
957
0
                       .getReg(0);
958
8
    };
959
14
  case AArch64CC::LS:
960
14
    return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
961
14
      return IsZero
962
14
                 ? MIB.buildInstr(AArch64::G_FCMLEZ, {DstTy}, {LHS}).getReg(0)
963
14
                 : MIB.buildInstr(AArch64::G_FCMGE, {DstTy}, {RHS, LHS})
964
7
                       .getReg(0);
965
14
    };
966
9
  case AArch64CC::MI:
967
9
    return [LHS, RHS, IsZero, DstTy](MachineIRBuilder &MIB) {
968
9
      return IsZero
969
9
                 ? MIB.buildInstr(AArch64::G_FCMLTZ, {DstTy}, {LHS}).getReg(0)
970
9
                 : MIB.buildInstr(AArch64::G_FCMGT, {DstTy}, {RHS, LHS})
971
6
                       .getReg(0);
972
9
    };
973
37
  }
974
37
}
975
976
/// Try to lower a vector G_FCMP \p MI into an AArch64-specific pseudo.
977
bool matchLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
978
3.71k
                          MachineIRBuilder &MIB) {
979
3.71k
  assert(MI.getOpcode() == TargetOpcode::G_FCMP);
980
0
  const auto &ST = MI.getMF()->getSubtarget<AArch64Subtarget>();
981
982
3.71k
  Register Dst = MI.getOperand(0).getReg();
983
3.71k
  LLT DstTy = MRI.getType(Dst);
984
3.71k
  if (!DstTy.isVector() || !ST.hasNEON())
985
3.68k
    return false;
986
35
  Register LHS = MI.getOperand(2).getReg();
987
35
  unsigned EltSize = MRI.getType(LHS).getScalarSizeInBits();
988
35
  if (EltSize == 16 && !ST.hasFullFP16())
989
0
    return false;
990
35
  if (EltSize != 16 && EltSize != 32 && EltSize != 64)
991
0
    return false;
992
993
35
  return true;
994
35
}
995
996
/// Try to lower a vector G_FCMP \p MI into an AArch64-specific pseudo.
997
void applyLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
998
35
                          MachineIRBuilder &MIB) {
999
35
  assert(MI.getOpcode() == TargetOpcode::G_FCMP);
1000
0
  const auto &ST = MI.getMF()->getSubtarget<AArch64Subtarget>();
1001
1002
35
  const auto &CmpMI = cast<GFCmp>(MI);
1003
1004
35
  Register Dst = CmpMI.getReg(0);
1005
35
  CmpInst::Predicate Pred = CmpMI.getCond();
1006
35
  Register LHS = CmpMI.getLHSReg();
1007
35
  Register RHS = CmpMI.getRHSReg();
1008
1009
35
  LLT DstTy = MRI.getType(Dst);
1010
1011
35
  auto Splat = getAArch64VectorSplat(*MRI.getVRegDef(RHS), MRI);
1012
1013
  // Compares against 0 have special target-specific pseudos.
1014
35
  bool IsZero = Splat && Splat->isCst() && Splat->getCst() == 0;
1015
1016
35
  bool Invert = false;
1017
35
  AArch64CC::CondCode CC, CC2 = AArch64CC::AL;
1018
35
  if (Pred == CmpInst::Predicate::FCMP_ORD && IsZero) {
1019
    // The special case "fcmp ord %a, 0" is the canonical check that LHS isn't
1020
    // NaN, so equivalent to a == a and doesn't need the two comparisons an
1021
    // "ord" normally would.
1022
0
    RHS = LHS;
1023
0
    IsZero = false;
1024
0
    CC = AArch64CC::EQ;
1025
0
  } else
1026
35
    changeVectorFCMPPredToAArch64CC(Pred, CC, CC2, Invert);
1027
1028
  // Instead of having an apply function, just build here to simplify things.
1029
35
  MIB.setInstrAndDebugLoc(MI);
1030
1031
35
  const bool NoNans =
1032
35
      ST.getTargetLowering()->getTargetMachine().Options.NoNaNsFPMath;
1033
1034
35
  auto Cmp = getVectorFCMP(CC, LHS, RHS, IsZero, NoNans, MRI);
1035
35
  Register CmpRes;
1036
35
  if (CC2 == AArch64CC::AL)
1037
33
    CmpRes = Cmp(MIB);
1038
2
  else {
1039
2
    auto Cmp2 = getVectorFCMP(CC2, LHS, RHS, IsZero, NoNans, MRI);
1040
2
    auto Cmp2Dst = Cmp2(MIB);
1041
2
    auto Cmp1Dst = Cmp(MIB);
1042
2
    CmpRes = MIB.buildOr(DstTy, Cmp1Dst, Cmp2Dst).getReg(0);
1043
2
  }
1044
35
  if (Invert)
1045
13
    CmpRes = MIB.buildNot(DstTy, CmpRes).getReg(0);
1046
35
  MRI.replaceRegWith(Dst, CmpRes);
1047
35
  MI.eraseFromParent();
1048
35
}
1049
1050
bool matchFormTruncstore(MachineInstr &MI, MachineRegisterInfo &MRI,
1051
39.3k
                         Register &SrcReg) {
1052
39.3k
  assert(MI.getOpcode() == TargetOpcode::G_STORE);
1053
0
  Register DstReg = MI.getOperand(0).getReg();
1054
39.3k
  if (MRI.getType(DstReg).isVector())
1055
660
    return false;
1056
  // Match a store of a truncate.
1057
38.7k
  if (!mi_match(DstReg, MRI, m_GTrunc(m_Reg(SrcReg))))
1058
33.9k
    return false;
1059
  // Only form truncstores for value types of max 64b.
1060
4.76k
  return MRI.getType(SrcReg).getSizeInBits() <= 64;
1061
38.7k
}
1062
1063
void applyFormTruncstore(MachineInstr &MI, MachineRegisterInfo &MRI,
1064
                         MachineIRBuilder &B, GISelChangeObserver &Observer,
1065
4.76k
                         Register &SrcReg) {
1066
4.76k
  assert(MI.getOpcode() == TargetOpcode::G_STORE);
1067
0
  Observer.changingInstr(MI);
1068
4.76k
  MI.getOperand(0).setReg(SrcReg);
1069
4.76k
  Observer.changedInstr(MI);
1070
4.76k
}
1071
1072
// Lower vector G_SEXT_INREG back to shifts for selection. We allowed them to
1073
// form in the first place for combine opportunities, so any remaining ones
1074
// at this stage need to be lowered back.
1075
9.30k
bool matchVectorSextInReg(MachineInstr &MI, MachineRegisterInfo &MRI) {
1076
9.30k
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1077
0
  Register DstReg = MI.getOperand(0).getReg();
1078
9.30k
  LLT DstTy = MRI.getType(DstReg);
1079
9.30k
  return DstTy.isVector();
1080
9.30k
}
1081
1082
void applyVectorSextInReg(MachineInstr &MI, MachineRegisterInfo &MRI,
1083
122
                          MachineIRBuilder &B, GISelChangeObserver &Observer) {
1084
122
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
1085
0
  B.setInstrAndDebugLoc(MI);
1086
122
  LegalizerHelper Helper(*MI.getMF(), Observer, B);
1087
122
  Helper.lower(MI, 0, /* Unused hint type */ LLT());
1088
122
}
1089
1090
/// Combine <N x t>, unused = unmerge(G_EXT <2*N x t> v, undef, N)
1091
///           => unused, <N x t> = unmerge v
1092
bool matchUnmergeExtToUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
1093
554
                              Register &MatchInfo) {
1094
554
  auto &Unmerge = cast<GUnmerge>(MI);
1095
554
  if (Unmerge.getNumDefs() != 2)
1096
163
    return false;
1097
391
  if (!MRI.use_nodbg_empty(Unmerge.getReg(1)))
1098
309
    return false;
1099
1100
82
  LLT DstTy = MRI.getType(Unmerge.getReg(0));
1101
82
  if (!DstTy.isVector())
1102
2
    return false;
1103
1104
80
  MachineInstr *Ext = getOpcodeDef(AArch64::G_EXT, Unmerge.getSourceReg(), MRI);
1105
80
  if (!Ext)
1106
80
    return false;
1107
1108
0
  Register ExtSrc1 = Ext->getOperand(1).getReg();
1109
0
  Register ExtSrc2 = Ext->getOperand(2).getReg();
1110
0
  auto LowestVal =
1111
0
      getIConstantVRegValWithLookThrough(Ext->getOperand(3).getReg(), MRI);
1112
0
  if (!LowestVal || LowestVal->Value.getZExtValue() != DstTy.getSizeInBytes())
1113
0
    return false;
1114
1115
0
  if (!getOpcodeDef<GImplicitDef>(ExtSrc2, MRI))
1116
0
    return false;
1117
1118
0
  MatchInfo = ExtSrc1;
1119
0
  return true;
1120
0
}
1121
1122
void applyUnmergeExtToUnmerge(MachineInstr &MI, MachineRegisterInfo &MRI,
1123
                              MachineIRBuilder &B,
1124
0
                              GISelChangeObserver &Observer, Register &SrcReg) {
1125
0
  Observer.changingInstr(MI);
1126
  // Swap dst registers.
1127
0
  Register Dst1 = MI.getOperand(0).getReg();
1128
0
  MI.getOperand(0).setReg(MI.getOperand(1).getReg());
1129
0
  MI.getOperand(1).setReg(Dst1);
1130
0
  MI.getOperand(2).setReg(SrcReg);
1131
0
  Observer.changedInstr(MI);
1132
0
}
1133
1134
// Match mul({z/s}ext, {z/s}ext) => {u/s}mull OR
1135
// Match v2s64 mul instructions, which will then be scalarised later on
1136
// Doing these two matches in one function to ensure that the order of matching
1137
// will always be the same.
1138
// Try lowering MUL to MULL before trying to scalarize if needed.
1139
2.49k
bool matchExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI) {
1140
  // Get the instructions that defined the source operand
1141
2.49k
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
1142
2.49k
  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
1143
2.49k
  MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
1144
1145
2.49k
  if (DstTy.isVector()) {
1146
    // If the source operands were EXTENDED before, then {U/S}MULL can be used
1147
385
    unsigned I1Opc = I1->getOpcode();
1148
385
    unsigned I2Opc = I2->getOpcode();
1149
385
    if (((I1Opc == TargetOpcode::G_ZEXT && I2Opc == TargetOpcode::G_ZEXT) ||
1150
385
         (I1Opc == TargetOpcode::G_SEXT && I2Opc == TargetOpcode::G_SEXT)) &&
1151
385
        (MRI.getType(I1->getOperand(0).getReg()).getScalarSizeInBits() ==
1152
0
         MRI.getType(I1->getOperand(1).getReg()).getScalarSizeInBits() * 2) &&
1153
385
        (MRI.getType(I2->getOperand(0).getReg()).getScalarSizeInBits() ==
1154
0
         MRI.getType(I2->getOperand(1).getReg()).getScalarSizeInBits() * 2)) {
1155
0
      return true;
1156
0
    }
1157
    // If result type is v2s64, scalarise the instruction
1158
385
    else if (DstTy == LLT::fixed_vector(2, 64)) {
1159
37
      return true;
1160
37
    }
1161
385
  }
1162
2.45k
  return false;
1163
2.49k
}
1164
1165
void applyExtMulToMULL(MachineInstr &MI, MachineRegisterInfo &MRI,
1166
37
                       MachineIRBuilder &B, GISelChangeObserver &Observer) {
1167
37
  assert(MI.getOpcode() == TargetOpcode::G_MUL &&
1168
37
         "Expected a G_MUL instruction");
1169
1170
  // Get the instructions that defined the source operand
1171
0
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
1172
37
  MachineInstr *I1 = getDefIgnoringCopies(MI.getOperand(1).getReg(), MRI);
1173
37
  MachineInstr *I2 = getDefIgnoringCopies(MI.getOperand(2).getReg(), MRI);
1174
1175
  // If the source operands were EXTENDED before, then {U/S}MULL can be used
1176
37
  unsigned I1Opc = I1->getOpcode();
1177
37
  unsigned I2Opc = I2->getOpcode();
1178
37
  if (((I1Opc == TargetOpcode::G_ZEXT && I2Opc == TargetOpcode::G_ZEXT) ||
1179
37
       (I1Opc == TargetOpcode::G_SEXT && I2Opc == TargetOpcode::G_SEXT)) &&
1180
37
      (MRI.getType(I1->getOperand(0).getReg()).getScalarSizeInBits() ==
1181
0
       MRI.getType(I1->getOperand(1).getReg()).getScalarSizeInBits() * 2) &&
1182
37
      (MRI.getType(I2->getOperand(0).getReg()).getScalarSizeInBits() ==
1183
0
       MRI.getType(I2->getOperand(1).getReg()).getScalarSizeInBits() * 2)) {
1184
1185
0
    B.setInstrAndDebugLoc(MI);
1186
0
    B.buildInstr(I1->getOpcode() == TargetOpcode::G_ZEXT ? AArch64::G_UMULL
1187
0
                                                         : AArch64::G_SMULL,
1188
0
                 {MI.getOperand(0).getReg()},
1189
0
                 {I1->getOperand(1).getReg(), I2->getOperand(1).getReg()});
1190
0
    MI.eraseFromParent();
1191
0
  }
1192
  // If result type is v2s64, scalarise the instruction
1193
37
  else if (DstTy == LLT::fixed_vector(2, 64)) {
1194
37
    LegalizerHelper Helper(*MI.getMF(), Observer, B);
1195
37
    B.setInstrAndDebugLoc(MI);
1196
37
    Helper.fewerElementsVector(
1197
37
        MI, 0,
1198
37
        DstTy.changeElementCount(
1199
37
            DstTy.getElementCount().divideCoefficientBy(2)));
1200
37
  }
1201
37
}
1202
1203
class AArch64PostLegalizerLoweringImpl : public Combiner {
1204
protected:
1205
  // TODO: Make CombinerHelper methods const.
1206
  mutable CombinerHelper Helper;
1207
  const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig;
1208
  const AArch64Subtarget &STI;
1209
1210
public:
1211
  AArch64PostLegalizerLoweringImpl(
1212
      MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
1213
      GISelCSEInfo *CSEInfo,
1214
      const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig,
1215
      const AArch64Subtarget &STI);
1216
1217
0
  static const char *getName() { return "AArch64PostLegalizerLowering"; }
1218
1219
  bool tryCombineAll(MachineInstr &I) const override;
1220
1221
private:
1222
#define GET_GICOMBINER_CLASS_MEMBERS
1223
#include "AArch64GenPostLegalizeGILowering.inc"
1224
#undef GET_GICOMBINER_CLASS_MEMBERS
1225
};
1226
1227
#define GET_GICOMBINER_IMPL
1228
#include "AArch64GenPostLegalizeGILowering.inc"
1229
#undef GET_GICOMBINER_IMPL
1230
1231
AArch64PostLegalizerLoweringImpl::AArch64PostLegalizerLoweringImpl(
1232
    MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC,
1233
    GISelCSEInfo *CSEInfo,
1234
    const AArch64PostLegalizerLoweringImplRuleConfig &RuleConfig,
1235
    const AArch64Subtarget &STI)
1236
    : Combiner(MF, CInfo, TPC, /*KB*/ nullptr, CSEInfo),
1237
      Helper(Observer, B, /*IsPreLegalize*/ true), RuleConfig(RuleConfig),
1238
      STI(STI),
1239
#define GET_GICOMBINER_CONSTRUCTOR_INITS
1240
#include "AArch64GenPostLegalizeGILowering.inc"
1241
#undef GET_GICOMBINER_CONSTRUCTOR_INITS
1242
13.7k
{
1243
13.7k
}
1244
1245
class AArch64PostLegalizerLowering : public MachineFunctionPass {
1246
public:
1247
  static char ID;
1248
1249
  AArch64PostLegalizerLowering();
1250
1251
381
  StringRef getPassName() const override {
1252
381
    return "AArch64PostLegalizerLowering";
1253
381
  }
1254
1255
  bool runOnMachineFunction(MachineFunction &MF) override;
1256
  void getAnalysisUsage(AnalysisUsage &AU) const override;
1257
1258
private:
1259
  AArch64PostLegalizerLoweringImplRuleConfig RuleConfig;
1260
};
1261
} // end anonymous namespace
1262
1263
381
void AArch64PostLegalizerLowering::getAnalysisUsage(AnalysisUsage &AU) const {
1264
381
  AU.addRequired<TargetPassConfig>();
1265
381
  AU.setPreservesCFG();
1266
381
  getSelectionDAGFallbackAnalysisUsage(AU);
1267
381
  MachineFunctionPass::getAnalysisUsage(AU);
1268
381
}
1269
1270
AArch64PostLegalizerLowering::AArch64PostLegalizerLowering()
1271
381
    : MachineFunctionPass(ID) {
1272
381
  initializeAArch64PostLegalizerLoweringPass(*PassRegistry::getPassRegistry());
1273
1274
381
  if (!RuleConfig.parseCommandLineOption())
1275
0
    report_fatal_error("Invalid rule identifier");
1276
381
}
1277
1278
14.8k
bool AArch64PostLegalizerLowering::runOnMachineFunction(MachineFunction &MF) {
1279
14.8k
  if (MF.getProperties().hasProperty(
1280
14.8k
          MachineFunctionProperties::Property::FailedISel))
1281
1.17k
    return false;
1282
13.7k
  assert(MF.getProperties().hasProperty(
1283
13.7k
             MachineFunctionProperties::Property::Legalized) &&
1284
13.7k
         "Expected a legalized function?");
1285
0
  auto *TPC = &getAnalysis<TargetPassConfig>();
1286
13.7k
  const Function &F = MF.getFunction();
1287
1288
13.7k
  const AArch64Subtarget &ST = MF.getSubtarget<AArch64Subtarget>();
1289
13.7k
  CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false,
1290
13.7k
                     /*LegalizerInfo*/ nullptr, /*OptEnabled=*/true,
1291
13.7k
                     F.hasOptSize(), F.hasMinSize());
1292
13.7k
  AArch64PostLegalizerLoweringImpl Impl(MF, CInfo, TPC, /*CSEInfo*/ nullptr,
1293
13.7k
                                        RuleConfig, ST);
1294
13.7k
  return Impl.combineMachineInstrs();
1295
14.8k
}
1296
1297
char AArch64PostLegalizerLowering::ID = 0;
1298
62
INITIALIZE_PASS_BEGIN(AArch64PostLegalizerLowering, DEBUG_TYPE,
1299
62
                      "Lower AArch64 MachineInstrs after legalization", false,
1300
62
                      false)
1301
62
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1302
62
INITIALIZE_PASS_END(AArch64PostLegalizerLowering, DEBUG_TYPE,
1303
                    "Lower AArch64 MachineInstrs after legalization", false,
1304
                    false)
1305
1306
namespace llvm {
1307
381
FunctionPass *createAArch64PostLegalizerLowering() {
1308
381
  return new AArch64PostLegalizerLowering();
1309
381
}
1310
} // end namespace llvm