|
| 1 | +//===-- X86FastTileConfig.cpp - Fast Tile Register Configure---------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +/// \file Pass to configure the shape of AMX physical registers. |
| 10 | +/// AMX registers need to be configured before use. The ldtilecfg instruction |
| 11 | +/// is inserted before the FastRegAllocation pass, but at that time we don't |
| 12 | +/// know the shape of each physical tile register, because register |
| 13 | +/// allocation is not done yet. This pass runs after the register allocation |
| 14 | +/// pass. It collects the shape information of each physical tile register |
| 15 | +/// and stores the shape in the stack slot that is allocated for loading the |
| 16 | +/// config into the tile config register. |
| 17 | +// |
| 18 | +//===----------------------------------------------------------------------===// |
| 19 | + |
| 20 | +#include "X86.h" |
| 21 | +#include "X86InstrBuilder.h" |
| 22 | +#include "X86MachineFunctionInfo.h" |
| 23 | +#include "X86RegisterInfo.h" |
| 24 | +#include "X86Subtarget.h" |
| 25 | +#include "llvm/CodeGen/MachineFrameInfo.h" |
| 26 | +#include "llvm/CodeGen/MachineFunctionPass.h" |
| 27 | +#include "llvm/CodeGen/MachineInstr.h" |
| 28 | +#include "llvm/CodeGen/MachineRegisterInfo.h" |
| 29 | +#include "llvm/CodeGen/Passes.h" |
| 30 | +#include "llvm/CodeGen/TargetInstrInfo.h" |
| 31 | +#include "llvm/CodeGen/TargetRegisterInfo.h" |
| 32 | +#include "llvm/InitializePasses.h" |
| 33 | + |
| 34 | +using namespace llvm; |
| 35 | + |
| 36 | +#define DEBUG_TYPE "fasttileconfig" |
| 37 | + |
| 38 | +namespace { |
| 39 | + |
| 40 | +class X86FastTileConfig : public MachineFunctionPass { |
| 41 | + // context |
| 42 | + MachineFunction *MF = nullptr; |
| 43 | + const X86Subtarget *ST = nullptr; |
| 44 | + const TargetRegisterInfo *TRI = nullptr; |
| 45 | + const TargetInstrInfo *TII = nullptr; |
| 46 | + MachineRegisterInfo *MRI = nullptr; |
| 47 | + |
| 48 | + MachineInstr *getTileConfigPoint(); |
| 49 | + void tileConfig(); |
| 50 | + |
| 51 | +public: |
| 52 | + X86FastTileConfig() : MachineFunctionPass(ID) {} |
| 53 | + |
| 54 | + bool fastTileConfig(); |
| 55 | + bool isTileLoad(MachineInstr &MI); |
| 56 | + bool isTileStore(MachineInstr &MI); |
| 57 | + bool isAMXInstr(MachineInstr &MI); |
| 58 | + void getTileStoreShape(MachineInstr &MI, |
| 59 | + SmallVector<MachineOperand *> &ShapedTiles); |
| 60 | + |
| 61 | + MachineInstr *getKeyAMXInstr(MachineInstr *MI); |
| 62 | + void getTileShapesCfg(MachineInstr *MI, |
| 63 | + SmallVector<MachineOperand *> &ShapedTiles); |
| 64 | + void getShapeCfgInstrs(MachineInstr *MI, |
| 65 | + std::map<unsigned, MachineInstr *> &RowCfgs, |
| 66 | + std::map<unsigned, MachineInstr *> &ColCfgs); |
| 67 | + |
| 68 | + /// Return the pass name. |
| 69 | + StringRef getPassName() const override { |
| 70 | + return "Fast Tile Register Configure"; |
| 71 | + } |
| 72 | + |
| 73 | + void materializeTileCfg(MachineInstr *MI); |
| 74 | + |
| 75 | + void rewriteTileCfg(SmallVector<MachineOperand *> &ShapedTiles, |
| 76 | + std::map<unsigned, MachineInstr *> &RowCfgs, |
| 77 | + std::map<unsigned, MachineInstr *> &ColCfgs); |
| 78 | + |
| 79 | + /// Perform register allocation. |
| 80 | + bool runOnMachineFunction(MachineFunction &MFunc) override; |
| 81 | + |
| 82 | + MachineFunctionProperties getRequiredProperties() const override { |
| 83 | + return MachineFunctionProperties().set( |
| 84 | + MachineFunctionProperties::Property::NoPHIs); |
| 85 | + } |
| 86 | + |
| 87 | + static char ID; |
| 88 | +}; |
| 89 | + |
| 90 | +} // end anonymous namespace |
| 91 | + |
| 92 | +char X86FastTileConfig::ID = 0; |
| 93 | + |
| 94 | +INITIALIZE_PASS_BEGIN(X86FastTileConfig, DEBUG_TYPE, |
| 95 | + "Fast Tile Register Configure", false, false) |
| 96 | +INITIALIZE_PASS_END(X86FastTileConfig, DEBUG_TYPE, |
| 97 | + "Fast Tile Register Configure", false, false) |
| 98 | + |
| 99 | +static bool isTilePhysReg(MachineOperand &Op) { |
| 100 | + if (!Op.isReg()) |
| 101 | + return false; |
| 102 | + |
| 103 | + Register Reg = Op.getReg(); |
| 104 | + if (Reg >= X86::TMM0 && Reg <= X86::TMM7) |
| 105 | + return true; |
| 106 | + return false; |
| 107 | +} |
| 108 | + |
| 109 | +static unsigned getTilePhysRegIdx(MachineOperand *Op) { |
| 110 | + assert(isTilePhysReg(*Op) && "Tile Operand is invalid"); |
| 111 | + return Op->getReg() - X86::TMM0; |
| 112 | +} |
| 113 | + |
| 114 | +static inline void adjustRowCfg(unsigned TIdx, MachineInstr *MI) { |
| 115 | + unsigned Offset = 48 + TIdx; |
| 116 | + MI->getOperand(3).ChangeToImmediate(Offset); |
| 117 | +} |
| 118 | + |
| 119 | +static inline void adjustColCfg(unsigned TIdx, MachineInstr *MI) { |
| 120 | + unsigned Offset = 16 + TIdx * 2; |
| 121 | + MI->getOperand(3).ChangeToImmediate(Offset); |
| 122 | +} |
| 123 | + |
| 124 | +bool X86FastTileConfig::isTileLoad(MachineInstr &MI) { |
| 125 | + return MI.getOpcode() == X86::PTILELOADDV; |
| 126 | +} |
| 127 | +bool X86FastTileConfig::isTileStore(MachineInstr &MI) { |
| 128 | + return MI.getOpcode() == X86::PTILESTOREDV; |
| 129 | +} |
| 130 | +bool X86FastTileConfig::isAMXInstr(MachineInstr &MI) { |
| 131 | + // TODO: May need to handle some special nontile amx instrucion. |
| 132 | + if (MI.getOpcode() == X86::PLDTILECFGV || MI.isDebugInstr()) |
| 133 | + return false; |
| 134 | + |
| 135 | + for (MachineOperand &MO : MI.operands()) |
| 136 | + if (isTilePhysReg(MO)) |
| 137 | + return true; |
| 138 | + |
| 139 | + return false; |
| 140 | +} |
| 141 | + |
| 142 | +MachineInstr *X86FastTileConfig::getKeyAMXInstr(MachineInstr *MI) { |
| 143 | + auto Cfg = MachineBasicBlock::iterator(MI); |
| 144 | + MachineBasicBlock *MBB = MI->getParent(); |
| 145 | + MachineInstr *KeyMI = nullptr; |
| 146 | + int KeyAMXNum = 0; |
| 147 | + |
| 148 | + for (auto II = Cfg; II != MBB->end(); II++) { |
| 149 | + if (isTileLoad(*II)) { |
| 150 | + KeyMI = &*II; |
| 151 | + continue; |
| 152 | + } |
| 153 | + |
| 154 | + if (isTileStore(*II)) { |
| 155 | + assert(KeyMI && "Key AMX Should be found before!"); |
| 156 | + break; |
| 157 | + } |
| 158 | + |
| 159 | + if (isAMXInstr(*II)) { |
| 160 | + assert((KeyAMXNum == 0) && "Too many Key AMX instruction!"); |
| 161 | + KeyAMXNum++; |
| 162 | + KeyMI = &*II; |
| 163 | + } |
| 164 | + } |
| 165 | + assert(KeyMI && "There must be an AMX instruction."); |
| 166 | + return KeyMI; |
| 167 | +} |
| 168 | + |
| 169 | +// Orderly get the tiles in key amx instruction, uses before defs. |
| 170 | +void X86FastTileConfig::getTileShapesCfg( |
| 171 | + MachineInstr *CfgMI, SmallVector<MachineOperand *> &ShapedTiles) { |
| 172 | + MachineInstr *KeyMI = getKeyAMXInstr(CfgMI); |
| 173 | + |
| 174 | + SmallVector<MachineOperand *> DefTiles; |
| 175 | + for (MachineOperand &MO : KeyMI->operands()) { |
| 176 | + if (!isTilePhysReg(MO)) |
| 177 | + continue; |
| 178 | + if (MO.isDef()) |
| 179 | + DefTiles.push_back(&MO); |
| 180 | + else |
| 181 | + ShapedTiles.push_back(&MO); |
| 182 | + } |
| 183 | + ShapedTiles.append(DefTiles); |
| 184 | +} |
| 185 | + |
| 186 | +// We pre-config the shapes at position named with "amx.tmm.N.shape.row* and |
| 187 | +// amx.shape.N.col*" at pass "Pre AMX Tile Config". |
| 188 | +// The 'N' implies the order of tiles in key amx intrinsic. |
| 189 | +void X86FastTileConfig::getShapeCfgInstrs( |
| 190 | + MachineInstr *MI, std::map<unsigned, MachineInstr *> &RowCfgs, |
| 191 | + std::map<unsigned, MachineInstr *> &ColCfgs) { |
| 192 | + auto Cfg = MachineBasicBlock::iterator(MI); |
| 193 | + MachineBasicBlock *MBB = MI->getParent(); |
| 194 | + |
| 195 | + for (auto II = Cfg; II != MBB->begin(); II--) { |
| 196 | + if (isAMXInstr(*II) || II->isTerminator() || II->isCall()) |
| 197 | + break; |
| 198 | + if (!II->mayStore() || !II->hasOneMemOperand()) |
| 199 | + continue; |
| 200 | + const Value *MemPtr = II->memoperands()[0]->getValue(); |
| 201 | + if (!MemPtr) |
| 202 | + continue; |
| 203 | + |
| 204 | + StringRef Name = MemPtr->getName(); |
| 205 | + if (!Name.startswith("amx.tmm.")) |
| 206 | + continue; |
| 207 | + |
| 208 | + // Get the 'N'th tile shape config in key amx instruction. |
| 209 | + auto N = Name.find(".shape"); |
| 210 | + StringRef STileIdx = Name.slice(8, N); |
| 211 | + unsigned Idx; |
| 212 | + STileIdx.getAsInteger(10, Idx); |
| 213 | + |
| 214 | + // And related them with their store instructions. |
| 215 | + if (Name.contains("row")) |
| 216 | + RowCfgs[Idx] = &*II; |
| 217 | + else if (Name.contains("col")) |
| 218 | + ColCfgs[Idx] = &*II; |
| 219 | + else |
| 220 | + llvm_unreachable("Invalid tile shape info!"); |
| 221 | + } |
| 222 | + assert((RowCfgs.size() == ColCfgs.size()) && |
| 223 | + "The number of tile row and col must be equal!"); |
| 224 | +} |
| 225 | + |
| 226 | +// Here is the data format for the tile config. |
| 227 | +// 0 palette = 1 now. |
| 228 | +// 1 start_row = 0 now. |
| 229 | +// 2-15 reserved, must be zero |
| 230 | +// 16-17 tile0.colsb Tile 0 bytes per row. |
| 231 | +// 18-19 tile1.colsb Tile 1 bytes per row. |
| 232 | +// 20-21 tile2.colsb Tile 2 bytes per row. |
| 233 | +// ... (sequence continues) |
| 234 | +// 30-31 tile7.colsb Tile 7 bytes per row. |
| 235 | +// 32-47 reserved, must be zero |
| 236 | +// 48 tile0.rows Tile 0 rows. |
| 237 | +// 49 tile1.rows Tile 1 rows. |
| 238 | +// 50 tile2.rows Tile 2 rows. |
| 239 | +// ... (sequence continues) |
| 240 | +// 55 tile7.rows Tile 7 rows. |
| 241 | +// 56-63 reserved, must be zero |
| 242 | +void X86FastTileConfig::rewriteTileCfg( |
| 243 | + SmallVector<MachineOperand *> &ShapedTiles, |
| 244 | + std::map<unsigned, MachineInstr *> &RowCfgs, |
| 245 | + std::map<unsigned, MachineInstr *> &ColCfgs) { |
| 246 | + assert((RowCfgs.size() == ShapedTiles.size()) && |
| 247 | + "The number of tile shapes not equal with the number of tiles!"); |
| 248 | + |
| 249 | + // Orderly get the tiles and adjust the shape config. |
| 250 | + for (unsigned I = 0, E = ShapedTiles.size(); I < E; I++) { |
| 251 | + MachineOperand *MO = ShapedTiles[I]; |
| 252 | + unsigned TmmIdx = getTilePhysRegIdx(MO); |
| 253 | + if (I == TmmIdx) |
| 254 | + continue; |
| 255 | + adjustRowCfg(TmmIdx, RowCfgs[I]); |
| 256 | + adjustColCfg(TmmIdx, ColCfgs[I]); |
| 257 | + } |
| 258 | +} |
| 259 | + |
| 260 | +// We have already preconfig the shapes before fast register allocation at |
| 261 | +// X86PreAMXConfig::preWriteTileCfg(). Now, we have done fast register |
| 262 | +// allocation, the shapes pre-written before may not rightly corresponding |
| 263 | +// to the correct tmm registers, so we need adjust them. |
| 264 | +void X86FastTileConfig::materializeTileCfg(MachineInstr *CfgMI) { |
| 265 | + SmallVector<MachineOperand *> ShapedTiles; |
| 266 | + std::map<unsigned, MachineInstr *> RowCfgs; |
| 267 | + std::map<unsigned, MachineInstr *> ColCfgs; |
| 268 | + |
| 269 | + // Orderly keep the tile uses and def in ShapedTiles; |
| 270 | + getTileShapesCfg(CfgMI, ShapedTiles); |
| 271 | + assert(ShapedTiles.size() && "Not find shapes config!"); |
| 272 | + |
| 273 | + getShapeCfgInstrs(CfgMI, RowCfgs, ColCfgs); |
| 274 | + |
| 275 | + rewriteTileCfg(ShapedTiles, RowCfgs, ColCfgs); |
| 276 | +} |
| 277 | + |
| 278 | +bool X86FastTileConfig::fastTileConfig() { |
| 279 | + bool Changed = false; |
| 280 | + |
| 281 | + for (MachineBasicBlock &MBB : *MF) { |
| 282 | + SmallVector<MachineInstr *, 2> CFGs; |
| 283 | + for (MachineInstr &MI : MBB) |
| 284 | + if (MI.getOpcode() == X86::PLDTILECFGV) |
| 285 | + CFGs.push_back(&MI); |
| 286 | + for (auto *MI : CFGs) |
| 287 | + materializeTileCfg(MI); |
| 288 | + if (!CFGs.empty()) |
| 289 | + Changed = true; |
| 290 | + } |
| 291 | + return Changed; |
| 292 | +} |
| 293 | + |
| 294 | +bool X86FastTileConfig::runOnMachineFunction(MachineFunction &MFunc) { |
| 295 | + MF = &MFunc; |
| 296 | + MRI = &MFunc.getRegInfo(); |
| 297 | + ST = &MFunc.getSubtarget<X86Subtarget>(); |
| 298 | + TRI = ST->getRegisterInfo(); |
| 299 | + TII = MFunc.getSubtarget().getInstrInfo(); |
| 300 | + |
| 301 | + return fastTileConfig(); |
| 302 | +} |
| 303 | + |
| 304 | +FunctionPass *llvm::createX86FastTileConfigPass() { |
| 305 | + return new X86FastTileConfig(); |
| 306 | +} |
0 commit comments