Skip to content

Commit 56ca11e

Browse files
committed
[RISCV] Add an MIR pass to replace redundant sext.w instructions with copies.
Function calls and compare instructions tend to cause sext.w instructions to be inserted. If we make good use of W instructions, these operations can often end up being redundant. We don't always detect these during SelectionDAG due to things like phis. There are also some cases caused by failure to turn extload into sextload in SelectionDAG: extload selects to LW, allowing later sext.w instructions to become redundant. This patch adds a pass that examines the input of sext.w instructions, trying to determine if it is already sign extended, either by finding a W instruction or other instructions that produce a sign extended result, or by looking through instructions that propagate sign bits. It uses a worklist and visited set to search as far back as necessary. Reviewed By: asb, kito-cheng Differential Revision: https://reviews.llvm.org/D116397
1 parent 2ccf0b7 commit 56ca11e

9 files changed

+627
-62
lines changed

llvm/lib/Target/RISCV/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_llvm_target(RISCVCodeGen
3535
RISCVMergeBaseOffset.cpp
3636
RISCVRegisterBankInfo.cpp
3737
RISCVRegisterInfo.cpp
38+
RISCVSExtWRemoval.cpp
3839
RISCVSubtarget.cpp
3940
RISCVTargetMachine.cpp
4041
RISCVTargetObjectFile.cpp

llvm/lib/Target/RISCV/RISCV.h

+3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ FunctionPass *createRISCVISelDag(RISCVTargetMachine &TM);
4040
FunctionPass *createRISCVGatherScatterLoweringPass();
4141
void initializeRISCVGatherScatterLoweringPass(PassRegistry &);
4242

43+
FunctionPass *createRISCVSExtWRemovalPass();
44+
void initializeRISCVSExtWRemovalPass(PassRegistry &);
45+
4346
FunctionPass *createRISCVMergeBaseOffsetOptPass();
4447
void initializeRISCVMergeBaseOffsetOptPass(PassRegistry &);
4548

+266
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
//===-------------- RISCVSExtWRemoval.cpp - MI sext.w Removal -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===---------------------------------------------------------------------===//
8+
//
9+
// This pass removes unneeded sext.w instructions at the MI level.
10+
//
11+
//===---------------------------------------------------------------------===//
12+
13+
#include "RISCV.h"
14+
#include "RISCVSubtarget.h"
15+
#include "llvm/ADT/Statistic.h"
16+
#include "llvm/CodeGen/MachineFunctionPass.h"
17+
#include "llvm/CodeGen/TargetInstrInfo.h"
18+
19+
using namespace llvm;
20+
21+
#define DEBUG_TYPE "riscv-sextw-removal"
22+
23+
STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
24+
25+
// Hidden command-line flag, off by default. When it is set,
// runOnMachineFunction returns early without touching the function.
static cl::opt<bool>
    DisableSExtWRemoval("riscv-disable-sextw-removal", cl::Hidden,
                        cl::init(false),
                        cl::desc("Disable removal of sext.w"));
28+
namespace {
29+
30+
class RISCVSExtWRemoval : public MachineFunctionPass {
31+
public:
32+
static char ID;
33+
34+
RISCVSExtWRemoval() : MachineFunctionPass(ID) {
35+
initializeRISCVSExtWRemovalPass(*PassRegistry::getPassRegistry());
36+
}
37+
38+
bool runOnMachineFunction(MachineFunction &MF) override;
39+
40+
void getAnalysisUsage(AnalysisUsage &AU) const override {
41+
AU.setPreservesCFG();
42+
MachineFunctionPass::getAnalysisUsage(AU);
43+
}
44+
45+
StringRef getPassName() const override { return "RISCV sext.w Removal"; }
46+
};
47+
48+
} // end anonymous namespace
49+
50+
char RISCVSExtWRemoval::ID = 0;
51+
INITIALIZE_PASS(RISCVSExtWRemoval, DEBUG_TYPE, "RISCV sext.w Removal", false,
52+
false)
53+
54+
/// Factory for the sext.w removal pass: allocates a fresh RISCVSExtWRemoval
/// and returns it through the FunctionPass base pointer.
FunctionPass *llvm::createRISCVSExtWRemovalPass() {
  FunctionPass *Pass = new RISCVSExtWRemoval();
  return Pass;
}
57+
58+
// This function returns true if the machine instruction always outputs a value
59+
// where bits 63:32 match bit 31.
60+
// TODO: Allocate a bit in TSFlags for the W instructions?
61+
// TODO: Add other W instructions.
62+
static bool isSignExtendingOpW(const MachineInstr &MI) {
63+
switch (MI.getOpcode()) {
64+
case RISCV::LUI:
65+
case RISCV::LW:
66+
case RISCV::ADDW:
67+
case RISCV::ADDIW:
68+
case RISCV::SUBW:
69+
case RISCV::MULW:
70+
case RISCV::SLLW:
71+
case RISCV::SLLIW:
72+
case RISCV::SRAW:
73+
case RISCV::SRAIW:
74+
case RISCV::SRLW:
75+
case RISCV::SRLIW:
76+
case RISCV::DIVW:
77+
case RISCV::DIVUW:
78+
case RISCV::REMW:
79+
case RISCV::REMUW:
80+
case RISCV::ROLW:
81+
case RISCV::RORW:
82+
case RISCV::RORIW:
83+
case RISCV::CLZW:
84+
case RISCV::CTZW:
85+
case RISCV::CPOPW:
86+
case RISCV::FCVT_W_H:
87+
case RISCV::FCVT_WU_H:
88+
case RISCV::FCVT_W_S:
89+
case RISCV::FCVT_WU_S:
90+
case RISCV::FCVT_W_D:
91+
case RISCV::FCVT_WU_D:
92+
// The following aren't W instructions, but are either sign extended from a
93+
// smaller size or put zeros in bits 63:31.
94+
case RISCV::LBU:
95+
case RISCV::LHU:
96+
case RISCV::LB:
97+
case RISCV::LH:
98+
case RISCV::SEXTB:
99+
case RISCV::SEXTH:
100+
case RISCV::ZEXTH_RV64:
101+
return true;
102+
}
103+
104+
// The LI pattern ADDI rd, X0, imm is sign extended.
105+
if (MI.getOpcode() == RISCV::ADDI && MI.getOperand(1).isReg() &&
106+
MI.getOperand(1).getReg() == RISCV::X0)
107+
return true;
108+
109+
// An ANDI with an 11 bit immediate will zero bits 63:11.
110+
if (MI.getOpcode() == RISCV::ANDI && isUInt<11>(MI.getOperand(2).getImm()))
111+
return true;
112+
113+
// Copying from X0 produces zero.
114+
if (MI.getOpcode() == RISCV::COPY && MI.getOperand(1).getReg() == RISCV::X0)
115+
return true;
116+
117+
return false;
118+
}
119+
120+
static bool isSignExtendedW(const MachineInstr &OrigMI,
121+
MachineRegisterInfo &MRI) {
122+
123+
SmallPtrSet<const MachineInstr *, 4> Visited;
124+
SmallVector<const MachineInstr *, 4> Worklist;
125+
126+
Worklist.push_back(&OrigMI);
127+
128+
while (!Worklist.empty()) {
129+
const MachineInstr *MI = Worklist.pop_back_val();
130+
131+
// If we already visited this instruction, we don't need to check it again.
132+
if (!Visited.insert(MI).second)
133+
continue;
134+
135+
// If this is a sign extending operation we don't need to look any further.
136+
if (isSignExtendingOpW(*MI))
137+
continue;
138+
139+
// Is this an instruction that propagates sign extend.
140+
switch (MI->getOpcode()) {
141+
default:
142+
// Unknown opcode, give up.
143+
return false;
144+
case RISCV::COPY: {
145+
Register SrcReg = MI->getOperand(1).getReg();
146+
147+
// TODO: Handle arguments and returns from calls?
148+
149+
// If this is a copy from another register, check its source instruction.
150+
if (!SrcReg.isVirtual())
151+
return false;
152+
const MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
153+
if (!SrcMI)
154+
return false;
155+
156+
// Add SrcMI to the worklist.
157+
Worklist.push_back(SrcMI);
158+
break;
159+
}
160+
case RISCV::ANDI:
161+
case RISCV::ORI:
162+
case RISCV::XORI: {
163+
// Logical operations use a sign extended 12-bit immediate. We just need
164+
// to check if the other operand is sign extended.
165+
Register SrcReg = MI->getOperand(1).getReg();
166+
if (!SrcReg.isVirtual())
167+
return false;
168+
const MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
169+
if (!SrcMI)
170+
return false;
171+
172+
// Add SrcMI to the worklist.
173+
Worklist.push_back(SrcMI);
174+
break;
175+
}
176+
case RISCV::AND:
177+
case RISCV::OR:
178+
case RISCV::XOR:
179+
case RISCV::ANDN:
180+
case RISCV::ORN:
181+
case RISCV::XNOR:
182+
case RISCV::MAX:
183+
case RISCV::MAXU:
184+
case RISCV::MIN:
185+
case RISCV::MINU:
186+
case RISCV::PHI: {
187+
// If all incoming values are sign-extended, the output of AND, OR, XOR,
188+
// MIN, MAX, or PHI is also sign-extended.
189+
190+
// The input registers for PHI are operand 1, 3, ...
191+
// The input registers for others are operand 1 and 2.
192+
unsigned E = 3, D = 1;
193+
if (MI->getOpcode() == RISCV::PHI) {
194+
E = MI->getNumOperands();
195+
D = 2;
196+
}
197+
198+
for (unsigned I = 1; I != E; I += D) {
199+
if (!MI->getOperand(I).isReg())
200+
return false;
201+
202+
Register SrcReg = MI->getOperand(I).getReg();
203+
if (!SrcReg.isVirtual())
204+
return false;
205+
const MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
206+
if (!SrcMI)
207+
return false;
208+
209+
// Add SrcMI to the worklist.
210+
Worklist.push_back(SrcMI);
211+
}
212+
213+
break;
214+
}
215+
}
216+
}
217+
218+
// If we get here, then every node we visited produces a sign extended value
219+
// or propagated sign extended values. So the result must be sign extended.
220+
return true;
221+
}
222+
223+
/// Removes redundant sext.w instructions (ADDIW rd, rs1, 0) from \p MF.
///
/// For each such instruction whose source is a virtual register already known
/// to be sign extended (per isSignExtendedW), all uses of the destination are
/// rewritten to the source and the instruction is erased.
///
/// \returns true if at least one instruction was removed. Returns false
/// without scanning when the function is skipped, the pass is disabled via
/// -riscv-disable-sextw-removal, or the subtarget is not 64-bit.
bool RISCVSExtWRemoval::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(MF.getFunction()) || DisableSExtWRemoval)
    return false;

  MachineRegisterInfo &MRI = MF.getRegInfo();
  const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();

  if (!ST.is64Bit())
    return false;

  bool MadeChange = false;
  for (MachineBasicBlock &MBB : MF) {
    for (auto I = MBB.begin(), IE = MBB.end(); I != IE;) {
      // Advance the iterator first: MI may be erased below.
      MachineInstr *MI = &*I++;

      // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
      if (MI->getOpcode() != RISCV::ADDIW || !MI->getOperand(2).isImm() ||
          MI->getOperand(2).getImm() != 0 || !MI->getOperand(1).isReg())
        continue;

      // Input should be a virtual register.
      Register SrcReg = MI->getOperand(1).getReg();
      if (!SrcReg.isVirtual())
        continue;

      // Nothing can be proven about a register with no defining instruction,
      // so leave the sext.w in place rather than dereferencing a null def
      // (this mirrors the null checks done inside isSignExtendedW).
      const MachineInstr *SrcMI = MRI.getVRegDef(SrcReg);
      if (!SrcMI || !isSignExtendedW(*SrcMI, MRI))
        continue;

      // The source takes over all uses of the destination, so it must be able
      // to satisfy the destination's register class.
      Register DstReg = MI->getOperand(0).getReg();
      if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
        continue;

      LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
      MRI.replaceRegWith(DstReg, SrcReg);
      // SrcReg now has additional uses, so its existing kill flags are
      // cleared.
      MRI.clearKillFlags(SrcReg);
      MI->eraseFromParent();
      ++NumRemovedSExtW;
      MadeChange = true;
    }
  }

  return MadeChange;
}

llvm/lib/Target/RISCV/RISCVTargetMachine.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() {
3939
initializeGlobalISel(*PR);
4040
initializeRISCVGatherScatterLoweringPass(*PR);
4141
initializeRISCVMergeBaseOffsetOptPass(*PR);
42+
initializeRISCVSExtWRemovalPass(*PR);
4243
initializeRISCVExpandPseudoPass(*PR);
4344
initializeRISCVInsertVSETVLIPass(*PR);
4445
}
@@ -140,6 +141,7 @@ class RISCVPassConfig : public TargetPassConfig {
140141
void addPreEmitPass() override;
141142
void addPreEmitPass2() override;
142143
void addPreSched2() override;
144+
void addMachineSSAOptimization() override;
143145
void addPreRegAlloc() override;
144146
};
145147
} // namespace
@@ -194,6 +196,13 @@ void RISCVPassConfig::addPreEmitPass2() {
194196
addPass(createRISCVExpandAtomicPseudoPass());
195197
}
196198

199+
void RISCVPassConfig::addMachineSSAOptimization() {
200+
TargetPassConfig::addMachineSSAOptimization();
201+
202+
if (TM->getTargetTriple().getArch() == Triple::riscv64)
203+
addPass(createRISCVSExtWRemovalPass());
204+
}
205+
197206
void RISCVPassConfig::addPreRegAlloc() {
198207
if (TM->getOptLevel() != CodeGenOpt::None)
199208
addPass(createRISCVMergeBaseOffsetOptPass());

0 commit comments

Comments
 (0)