Coverage Report

Created: 2024-01-17 10:31

/src/llvm-project/llvm/lib/Target/X86/X86LowerTileCopy.cpp
Line
Count
Source (jump to first uncovered line)
1
//===-- X86LowerTileCopy.cpp - Expand Tile Copy Instructions---------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file defines the pass which lowers AMX tile copy instructions. Since
10
// there is no tile copy instruction, we store the source tile to the stack
11
// and reload it into another tile register. We need an extra GR to hold
12
// the stride, and a stack slot to hold the tile data.
13
// This pass runs after copy propagation so that we don't miss any copy
14
// optimization. It also runs before prolog/epilog insertion,
15
// so that it can still allocate stack slots.
16
//
17
//===----------------------------------------------------------------------===//
18
19
#include "X86.h"
20
#include "X86InstrBuilder.h"
21
#include "X86InstrInfo.h"
22
#include "X86Subtarget.h"
23
#include "llvm/CodeGen/MachineBasicBlock.h"
24
#include "llvm/CodeGen/MachineFrameInfo.h"
25
#include "llvm/CodeGen/MachineFunction.h"
26
#include "llvm/CodeGen/MachineFunctionPass.h"
27
#include "llvm/CodeGen/MachineInstr.h"
28
#include "llvm/CodeGen/MachineInstrBuilder.h"
29
#include "llvm/CodeGen/MachineOperand.h"
30
#include "llvm/CodeGen/Passes.h"
31
#include "llvm/IR/DebugLoc.h"
32
#include "llvm/InitializePasses.h"
33
#include "llvm/Support/Debug.h"
34
35
using namespace llvm;
36
37
#define DEBUG_TYPE "x86-lower-tile-copy"
38
39
namespace {
40
41
class X86LowerTileCopy : public MachineFunctionPass {
42
public:
43
  static char ID;
44
45
662
  X86LowerTileCopy() : MachineFunctionPass(ID) {}
46
47
  void getAnalysisUsage(AnalysisUsage &AU) const override;
48
49
  bool runOnMachineFunction(MachineFunction &MF) override;
50
51
662
  StringRef getPassName() const override { return "X86 Lower Tile Copy"; }
52
};
53
54
} // namespace
55
56
char X86LowerTileCopy::ID = 0;
57
58
62
INITIALIZE_PASS_BEGIN(X86LowerTileCopy, "lowertilecopy", "Tile Copy Lowering",
59
62
                      false, false)
60
62
INITIALIZE_PASS_END(X86LowerTileCopy, "lowertilecopy", "Tile Copy Lowering",
61
                    false, false)
62
63
662
/// Declares that this pass preserves all analyses, then lets the
/// MachineFunctionPass base add its own requirements.
void X86LowerTileCopy::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  // Base-class hook must still run so machine-function bookkeeping is set up.
  MachineFunctionPass::getAnalysisUsage(AU);
}
67
68
662
/// Creates a new, heap-allocated instance of the tile-copy lowering pass.
/// Ownership passes to the caller.
FunctionPass *llvm::createX86LowerTileCopyPass() {
  return new X86LowerTileCopy;
}
71
72
22.2k
/// Lowers every COPY whose source and destination are both TILE registers
/// into a round trip through a stack slot, because there is no tile-to-tile
/// move instruction. For `%tmmD = COPY %tmmS` this emits, with RAX borrowed
/// as the stride register:
///
///   IMPLICIT_DEF $rax
///   MOV64mr    StrideSS, $rax        ; save RAX
///   MOV64ri    $rax, 64              ; stride
///   TILESTORED TileSS[$rax], %tmmS
///   TILELOADD  %tmmD, TileSS[$rax]
///   MOV64rm    $rax, StrideSS        ; restore RAX
///
/// Both the store and the load must carry the stride in the SIB index
/// register: if the index is omitted, the hardware uses a stride of zero.
///
/// \returns true if at least one COPY was rewritten.
bool X86LowerTileCopy::runOnMachineFunction(MachineFunction &MF) {
  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
  const X86InstrInfo *TII = ST.getInstrInfo();
  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
  MachineFrameInfo &MFI = MF.getFrameInfo();

  // Scratch slots for the tile data and for the saved stride register. Each
  // lowered copy touches them only between its own save and restore, so one
  // pair can be shared by all copies in the function. They are created lazily
  // so functions without tile copies keep an unchanged frame.
  // CreateSpillStackObject never returns a negative index, so -1 means "not
  // created yet".
  int TileSS = -1;
  int StrideSS = -1;

  // TODO: Pick a dead register to avoid the save/reload. Live intervals are
  // not available at this stage.
  const Register GR64Cand = X86::RAX;

  bool Changed = false;
  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) {
      if (!MI.isCopy())
        continue;
      MachineOperand &DstMO = MI.getOperand(0);
      MachineOperand &SrcMO = MI.getOperand(1);
      Register SrcReg = SrcMO.getReg();
      Register DstReg = DstMO.getReg();
      if (!X86::TILERegClass.contains(DstReg, SrcReg))
        continue;

      if (TileSS < 0) {
        TileSS = MFI.CreateSpillStackObject(
            TRI->getSpillSize(X86::TILERegClass),
            TRI->getSpillAlign(X86::TILERegClass));
        StrideSS = MFI.CreateSpillStackObject(
            TRI->getSpillSize(X86::GR64RegClass),
            TRI->getSpillAlign(X86::GR64RegClass));
      }

      const DebugLoc &DL = MI.getDebugLoc();

      // Give the stride register a definition so that saving it never reads
      // an undefined register.
      BuildMI(MBB, MI, DL, TII->get(X86::IMPLICIT_DEF), GR64Cand);
      // mov %rax, (StrideSS)
      addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV64mr)), StrideSS)
          .addReg(GR64Cand);
      // mov $64, %rax
      BuildMI(MBB, MI, DL, TII->get(X86::MOV64ri), GR64Cand).addImm(64);

      // tilestored %tmmS, (TileSS, %rax)
      MachineInstr *StoreMI =
          addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::TILESTORED)),
                            TileSS)
              .addReg(SrcReg, getKillRegState(SrcMO.isKill()));
      // The memory reference starts at operand 0 (no defs). The stride is
      // still needed by the load below, so it is not killed here.
      StoreMI->getOperand(X86::AddrIndexReg).setReg(GR64Cand);

      // tileloadd (TileSS, %rax), %tmmD
      MachineInstr *LoadMI = addFrameReference(
          BuildMI(MBB, MI, DL, TII->get(X86::TILELOADD), DstReg), TileSS);
      // Operand 0 is the destination tile; the memory reference follows it.
      // Without this index the load would use a zero stride.
      MachineOperand &LoadIndex = LoadMI->getOperand(1 + X86::AddrIndexReg);
      LoadIndex.setReg(GR64Cand);
      LoadIndex.setIsKill(true);

      // Restore the borrowed register: mov (StrideSS), %rax
      addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV64rm), GR64Cand),
                        StrideSS);

      MI.eraseFromParent();
      Changed = true;
    }
  }
  return Changed;
}