Commit d4bdeca

[X86] Support AMX fast register allocation

Differential Revision: https://reviews.llvm.org/D100026

1 parent 72bd011 · commit d4bdeca

24 files changed: +6950 -29 lines

clang/include/clang/Basic/BuiltinsX86_64.def (+1)

@@ -101,6 +101,7 @@ TARGET_BUILTIN(__builtin_ia32_testui, "Uc", "n", "uintr")
 TARGET_BUILTIN(__builtin_ia32_senduipi, "vUWi", "n", "uintr")
 
 // AMX internal builtin
+TARGET_BUILTIN(__builtin_ia32_tile_loadconfig_internal, "vvC*", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tileloadd64_internal, "V256iUsUsvC*z", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tdpbssd_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
 TARGET_BUILTIN(__builtin_ia32_tdpbsud_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
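
The new builtin takes a single pointer to the 64-byte tile configuration area (signature "vvC*", i.e. void(const void *)). It is an internal builtin that the AMX lowering is expected to emit rather than something users would normally write by hand; still, here is a minimal hand-written sketch of a call, with illustrative buffer contents following the ldtilecfg layout documented later in this commit:

// Illustrative sketch only: this internal builtin is normally emitted by the
// compiler's AMX lowering. Requires a target with -mamx-tile.
#include <cstring>

void configOneTile() {
  alignas(64) unsigned char Cfg[64];
  std::memset(Cfg, 0, sizeof(Cfg));
  Cfg[0] = 1;   // palette = 1
  Cfg[16] = 64; // tile0.colsb: tile 0 bytes per row (bytes 16-17)
  Cfg[48] = 16; // tile0.rows:  tile 0 rows (byte 48)
  __builtin_ia32_tile_loadconfig_internal(Cfg); // load the tile configuration
}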

llvm/include/llvm/CodeGen/Passes.h (+3)

@@ -507,6 +507,9 @@ namespace llvm {
   /// or split the data to two <128 x i32>.
   FunctionPass *createX86LowerAMXTypePass();
 
+  /// The pass inserts tile config intrinsics for AMX fast register allocation.
+  FunctionPass *createX86PreAMXConfigPass();
+
   /// The pass transforms amx intrinsics to scalar operation if the function has
   /// optnone attribute or it is O0.
   FunctionPass *createX86LowerAMXIntrinsicsPass();

llvm/include/llvm/CodeGen/TargetPassConfig.h (+4)

@@ -406,6 +406,10 @@ class TargetPassConfig : public ImmutablePass {
     return false;
   }
 
+  /// addPostFastRegAllocRewrite - Add passes to the fast register allocation
+  /// pipeline after fast register allocation is complete.
+  virtual bool addPostFastRegAllocRewrite() { return false; }
+
   /// Add passes to be run immediately after virtual registers are rewritten
   /// to physical registers.
   virtual void addPostRewrite() { }
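
The new hook gives targets a chance to run machine passes immediately after the fast register allocator, mirroring the existing addPostRewrite hook of the optimized pipeline. A minimal sketch of a target override (the actual X86 override presumably lives in X86TargetMachine.cpp, one of the 24 changed files not shown in this excerpt):

// Sketch, assuming the usual X86PassConfig subclass of TargetPassConfig.
bool X86PassConfig::addPostFastRegAllocRewrite() {
  // Configure the tile registers once the fast register allocator has
  // assigned physical TMM registers.
  addPass(createX86FastTileConfigPass());
  return true;
}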

llvm/include/llvm/IR/IntrinsicsX86.td (+3)

@@ -5042,6 +5042,9 @@ let TargetPrefix = "x86" in {
                          [ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<1>>,
                           ImmArg<ArgIndex<2>>]>;
   // AMX - internal intrinsics
+  def int_x86_ldtilecfg_internal :
+              GCCBuiltin<"__builtin_ia32_tile_loadconfig_internal">,
+              Intrinsic<[], [llvm_ptr_ty], []>;
   def int_x86_tileloadd64_internal :
               GCCBuiltin<"__builtin_ia32_tileloadd64_internal">,
               Intrinsic<[llvm_x86amx_ty],
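
The new intrinsic llvm.x86.ldtilecfg.internal returns void and takes a single pointer to the configuration buffer. A sketch of how a pass (for instance the new X86PreAMXConfig pass, whose source is among the changed files but not reproduced in this excerpt) might emit it with IRBuilder; Builder and CfgPtr are assumed to be provided by the surrounding code:

// Sketch only: emit a call to the new internal intrinsic.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsX86.h"

static void emitLdTileCfg(llvm::IRBuilder<> &Builder, llvm::Value *CfgPtr) {
  // void @llvm.x86.ldtilecfg.internal(i8*) -- no overloaded types, one arg.
  Builder.CreateIntrinsic(llvm::Intrinsic::x86_ldtilecfg_internal,
                          /*Types=*/{}, /*Args=*/{CfgPtr});
}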

llvm/lib/CodeGen/TargetPassConfig.cpp (+4)

@@ -1321,6 +1321,10 @@ bool TargetPassConfig::addRegAssignAndRewriteFast() {
     report_fatal_error("Must use fast (default) register allocator for unoptimized regalloc.");
 
   addPass(createRegAllocPass(false));
+
+  // Allow targets to change the register assignments after
+  // fast register allocation.
+  addPostFastRegAllocRewrite();
   return true;
 }

llvm/lib/Target/X86/CMakeLists.txt (+2)

@@ -34,8 +34,10 @@ set(sources
   X86DiscriminateMemOps.cpp
   X86LowerTileCopy.cpp
   X86LowerAMXType.cpp
+  X86PreAMXConfig.cpp
   X86LowerAMXIntrinsics.cpp
   X86TileConfig.cpp
+  X86FastTileConfig.cpp
   X86PreTileConfig.cpp
   X86ExpandPseudo.cpp
   X86FastISel.cpp

llvm/lib/Target/X86/X86.h (+5)

@@ -79,6 +79,9 @@ FunctionPass *createX86WinAllocaExpander();
 /// Return a pass that config the tile registers.
 FunctionPass *createX86TileConfigPass();
 
+/// Return a pass that config the tile registers after fast reg allocation.
+FunctionPass *createX86FastTileConfigPass();
+
 /// Return a pass that insert pseudo tile config instruction.
 FunctionPass *createX86PreTileConfigPass();
 
@@ -172,8 +175,10 @@ void initializeX86PartialReductionPass(PassRegistry &);
 void initializeX86SpeculativeLoadHardeningPassPass(PassRegistry &);
 void initializeX86SpeculativeExecutionSideEffectSuppressionPass(PassRegistry &);
 void initializeX86PreTileConfigPass(PassRegistry &);
+void initializeX86FastTileConfigPass(PassRegistry &);
 void initializeX86TileConfigPass(PassRegistry &);
 void initializeX86LowerAMXTypeLegacyPassPass(PassRegistry &);
+void initializeX86PreAMXConfigPassPass(PassRegistry &);
 void initializeX86LowerTileCopyPass(PassRegistry &);
 void initializeX86LowerAMXIntrinsicsLegacyPassPass(PassRegistry &);

llvm/lib/Target/X86/X86ExpandPseudo.cpp (+4)

@@ -478,6 +478,10 @@ bool X86ExpandPseudo::ExpandMI(MachineBasicBlock &MBB,
   case TargetOpcode::ICALL_BRANCH_FUNNEL:
     ExpandICallBranchFunnel(&MBB, MBBI);
     return true;
+  case X86::PLDTILECFGV: {
+    MI.setDesc(TII->get(X86::LDTILECFG));
+    return true;
+  }
   case X86::PTILELOADDV: {
     for (unsigned i = 2; i > 0; --i)
       MI.RemoveOperand(i);
llvm/lib/Target/X86/X86FastTileConfig.cpp (new file, +306)

@@ -0,0 +1,306 @@
//===-- X86FastTileConfig.cpp - Fast Tile Register Configure --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Pass to config the shape of AMX physical registers
/// AMX registers need to be configured before use. Before the fast register
/// allocation pass the ldtilecfg instruction is inserted, however at that
/// time we don't know the shape of each physical tile register, because
/// register allocation is not done yet. This pass runs after register
/// allocation. It collects the shape information of each physical tile
/// register and stores the shapes in the stack slot that is loaded into the
/// tile config register by ldtilecfg.
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrBuilder.h"
#include "X86MachineFunctionInfo.h"
#include "X86RegisterInfo.h"
#include "X86Subtarget.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/InitializePasses.h"

using namespace llvm;

#define DEBUG_TYPE "fasttileconfig"

namespace {

class X86FastTileConfig : public MachineFunctionPass {
  // context
  MachineFunction *MF = nullptr;
  const X86Subtarget *ST = nullptr;
  const TargetRegisterInfo *TRI = nullptr;
  const TargetInstrInfo *TII = nullptr;
  MachineRegisterInfo *MRI = nullptr;

  MachineInstr *getTileConfigPoint();
  void tileConfig();

public:
  X86FastTileConfig() : MachineFunctionPass(ID) {}

  bool fastTileConfig();
  bool isTileLoad(MachineInstr &MI);
  bool isTileStore(MachineInstr &MI);
  bool isAMXInstr(MachineInstr &MI);
  void getTileStoreShape(MachineInstr &MI,
                         SmallVector<MachineOperand *> &ShapedTiles);

  MachineInstr *getKeyAMXInstr(MachineInstr *MI);
  void getTileShapesCfg(MachineInstr *MI,
                        SmallVector<MachineOperand *> &ShapedTiles);
  void getShapeCfgInstrs(MachineInstr *MI,
                         std::map<unsigned, MachineInstr *> &RowCfgs,
                         std::map<unsigned, MachineInstr *> &ColCfgs);

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Fast Tile Register Configure";
  }

  void materializeTileCfg(MachineInstr *MI);

  void rewriteTileCfg(SmallVector<MachineOperand *> &ShapedTiles,
                      std::map<unsigned, MachineInstr *> &RowCfgs,
                      std::map<unsigned, MachineInstr *> &ColCfgs);

  /// Perform register allocation.
  bool runOnMachineFunction(MachineFunction &MFunc) override;

  MachineFunctionProperties getRequiredProperties() const override {
    return MachineFunctionProperties().set(
        MachineFunctionProperties::Property::NoPHIs);
  }

  static char ID;
};

} // end anonymous namespace

char X86FastTileConfig::ID = 0;

INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE,
                      "Fast Tile Register Configure", false, false)
INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE,
                    "Fast Tile Register Configure", false, false)

static bool isTilePhysReg(MachineOperand &Op) {
  if (!Op.isReg())
    return false;

  Register Reg = Op.getReg();
  if (Reg >= X86::TMM0 && Reg <= X86::TMM7)
    return true;
  return false;
}

static unsigned getTilePhysRegIdx(MachineOperand *Op) {
  assert(isTilePhysReg(*Op) && "Tile Operand is invalid");
  return Op->getReg() - X86::TMM0;
}

static inline void adjustRowCfg(unsigned TIdx, MachineInstr *MI) {
  unsigned Offset = 48 + TIdx;
  MI->getOperand(3).ChangeToImmediate(Offset);
}

static inline void adjustColCfg(unsigned TIdx, MachineInstr *MI) {
  unsigned Offset = 16 + TIdx * 2;
  MI->getOperand(3).ChangeToImmediate(Offset);
}

bool X86FastTileConfig::isTileLoad(MachineInstr &MI) {
  return MI.getOpcode() == X86::PTILELOADDV;
}
bool X86FastTileConfig::isTileStore(MachineInstr &MI) {
  return MI.getOpcode() == X86::PTILESTOREDV;
}
bool X86FastTileConfig::isAMXInstr(MachineInstr &MI) {
  // TODO: May need to handle some special non-tile AMX instructions.
  if (MI.getOpcode() == X86::PLDTILECFGV || MI.isDebugInstr())
    return false;

  for (MachineOperand &MO : MI.operands())
    if (isTilePhysReg(MO))
      return true;

  return false;
}

MachineInstr *X86FastTileConfig::getKeyAMXInstr(MachineInstr *MI) {
  auto Cfg = MachineBasicBlock::iterator(MI);
  MachineBasicBlock *MBB = MI->getParent();
  MachineInstr *KeyMI = nullptr;
  int KeyAMXNum = 0;

  for (auto II = Cfg; II != MBB->end(); II++) {
    if (isTileLoad(*II)) {
      KeyMI = &*II;
      continue;
    }

    if (isTileStore(*II)) {
      assert(KeyMI && "Key AMX Should be found before!");
      break;
    }

    if (isAMXInstr(*II)) {
      assert((KeyAMXNum == 0) && "Too many Key AMX instruction!");
      KeyAMXNum++;
      KeyMI = &*II;
    }
  }
  assert(KeyMI && "There must be an AMX instruction.");
  return KeyMI;
}

// Collect the tiles of the key AMX instruction in order, uses before defs.
void X86FastTileConfig::getTileShapesCfg(
    MachineInstr *CfgMI, SmallVector<MachineOperand *> &ShapedTiles) {
  MachineInstr *KeyMI = getKeyAMXInstr(CfgMI);

  SmallVector<MachineOperand *> DefTiles;
  for (MachineOperand &MO : KeyMI->operands()) {
    if (!isTilePhysReg(MO))
      continue;
    if (MO.isDef())
      DefTiles.push_back(&MO);
    else
      ShapedTiles.push_back(&MO);
  }
  ShapedTiles.append(DefTiles);
}

// We pre-config the shapes at stack slots named "amx.tmm.N.shape.row*" and
// "amx.tmm.N.shape.col*" in the "Pre AMX Tile Config" pass.
// The 'N' implies the order of the tiles in the key AMX intrinsic.
void X86FastTileConfig::getShapeCfgInstrs(
    MachineInstr *MI, std::map<unsigned, MachineInstr *> &RowCfgs,
    std::map<unsigned, MachineInstr *> &ColCfgs) {
  auto Cfg = MachineBasicBlock::iterator(MI);
  MachineBasicBlock *MBB = MI->getParent();

  for (auto II = Cfg; II != MBB->begin(); II--) {
    if (isAMXInstr(*II) || II->isTerminator() || II->isCall())
      break;
    if (!II->mayStore() || !II->hasOneMemOperand())
      continue;
    const Value *MemPtr = II->memoperands()[0]->getValue();
    if (!MemPtr)
      continue;

    StringRef Name = MemPtr->getName();
    if (!Name.startswith("amx.tmm."))
      continue;

    // Get the 'N'th tile shape config in the key AMX instruction.
    auto N = Name.find(".shape");
    StringRef STileIdx = Name.slice(8, N);
    unsigned Idx;
    STileIdx.getAsInteger(10, Idx);

    // And relate them to their store instructions.
    if (Name.contains("row"))
      RowCfgs[Idx] = &*II;
    else if (Name.contains("col"))
      ColCfgs[Idx] = &*II;
    else
      llvm_unreachable("Invalid tile shape info!");
  }
  assert((RowCfgs.size() == ColCfgs.size()) &&
         "The number of tile row and col must be equal!");
}

// Here is the data format for the tile config.
// 0      palette = 1 now.
// 1      start_row = 0 now.
// 2-15   reserved, must be zero
// 16-17  tile0.colsb Tile 0 bytes per row.
// 18-19  tile1.colsb Tile 1 bytes per row.
// 20-21  tile2.colsb Tile 2 bytes per row.
// ...    (sequence continues)
// 30-31  tile7.colsb Tile 7 bytes per row.
// 32-47  reserved, must be zero
// 48     tile0.rows Tile 0 rows.
// 49     tile1.rows Tile 1 rows.
// 50     tile2.rows Tile 2 rows.
// ...    (sequence continues)
// 55     tile7.rows Tile 7 rows.
// 56-63  reserved, must be zero
void X86FastTileConfig::rewriteTileCfg(
    SmallVector<MachineOperand *> &ShapedTiles,
    std::map<unsigned, MachineInstr *> &RowCfgs,
    std::map<unsigned, MachineInstr *> &ColCfgs) {
  assert((RowCfgs.size() == ShapedTiles.size()) &&
         "The number of tile shapes not equal with the number of tiles!");

  // Get the tiles in order and adjust the shape config.
  for (unsigned I = 0, E = ShapedTiles.size(); I < E; I++) {
    MachineOperand *MO = ShapedTiles[I];
    unsigned TmmIdx = getTilePhysRegIdx(MO);
    if (I == TmmIdx)
      continue;
    adjustRowCfg(TmmIdx, RowCfgs[I]);
    adjustColCfg(TmmIdx, ColCfgs[I]);
  }
}

// We have already pre-configured the shapes before fast register allocation
// in X86PreAMXConfig::preWriteTileCfg(). Now that fast register allocation is
// done, the shapes written earlier may not correspond to the correct tmm
// registers, so we need to adjust them.
void X86FastTileConfig::materializeTileCfg(MachineInstr *CfgMI) {
  SmallVector<MachineOperand *> ShapedTiles;
  std::map<unsigned, MachineInstr *> RowCfgs;
  std::map<unsigned, MachineInstr *> ColCfgs;

  // Keep the tile uses and defs in order in ShapedTiles.
  getTileShapesCfg(CfgMI, ShapedTiles);
  assert(ShapedTiles.size() && "Not find shapes config!");

  getShapeCfgInstrs(CfgMI, RowCfgs, ColCfgs);

  rewriteTileCfg(ShapedTiles, RowCfgs, ColCfgs);
}

bool X86FastTileConfig::fastTileConfig() {
  bool Changed = false;

  for (MachineBasicBlock &MBB : *MF) {
    SmallVector<MachineInstr *, 2> CFGs;
    for (MachineInstr &MI : MBB)
      if (MI.getOpcode() == X86::PLDTILECFGV)
        CFGs.push_back(&MI);
    for (auto *MI : CFGs)
      materializeTileCfg(MI);
    if (!CFGs.empty())
      Changed = true;
  }
  return Changed;
}

bool X86FastTileConfig::runOnMachineFunction(MachineFunction &MFunc) {
  MF = &MFunc;
  MRI = &MFunc.getRegInfo();
  ST = &MFunc.getSubtarget<X86Subtarget>();
  TRI = ST->getRegisterInfo();
  TII = MFunc.getSubtarget().getInstrInfo();

  return fastTileConfig();
}

FunctionPass *llvm::createX86FastTileConfigPass() {
  return new X86FastTileConfig();
}
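
The offsets used by adjustRowCfg and adjustColCfg follow the ldtilecfg layout documented above rewriteTileCfg: the row count of tile N is a single byte at offset 48 + N, and the bytes-per-row (colsb) of tile N is a 16-bit value at offset 16 + 2*N. A small standalone illustration of that layout, using hypothetical helper names:

// Standalone sketch of the 64-byte ldtilecfg layout; names are illustrative.
#include <cassert>
#include <cstdint>
#include <cstring>

struct TileConfig {
  uint8_t Bytes[64] = {};

  void init() {
    std::memset(Bytes, 0, sizeof(Bytes));
    Bytes[0] = 1; // palette = 1
    Bytes[1] = 0; // start_row = 0
  }

  // Rows of tile TIdx live in the single byte at offset 48 + TIdx.
  void setRows(unsigned TIdx, uint8_t Rows) {
    assert(TIdx < 8 && "only TMM0-TMM7 exist");
    Bytes[48 + TIdx] = Rows;
  }

  // Bytes-per-row of tile TIdx is the 16-bit value at offset 16 + 2 * TIdx.
  void setColsB(unsigned TIdx, uint16_t ColsB) {
    assert(TIdx < 8 && "only TMM0-TMM7 exist");
    std::memcpy(&Bytes[16 + TIdx * 2], &ColsB, sizeof(ColsB));
  }
};

int main() {
  TileConfig Cfg;
  Cfg.init();
  // E.g. TMM2 with 16 rows of 64 bytes: row byte at offset 50, colsb at
  // offsets 20-21, matching adjustRowCfg(2, ...) and adjustColCfg(2, ...).
  Cfg.setRows(2, 16);
  Cfg.setColsB(2, 64);
  return 0;
}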
